CUTLASS 3.1 Python interface documentation (#917)
* Add 12.1 Dockerfile * Add 3.1 docs
This commit is contained in:
923
python/docs/_modules/cutlass/emit/pytorch.html
Normal file
923
python/docs/_modules/cutlass/emit/pytorch.html
Normal file
@ -0,0 +1,923 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/emit/pytorch.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.emit.pytorch - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.emit.pytorch</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Utilities for generating source for building a PyTorch CUDA extension that uses a CUTLASS kernel.</span>
|
||||
<span class="sd">If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.</span>
|
||||
|
||||
<span class="sd">Example usage with JIT compilation:</span>
|
||||
|
||||
<span class="sd">.. highlight:: python</span>
|
||||
<span class="sd">.. code-block:: python</span>
|
||||
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> op = plan.construct()</span>
|
||||
<span class="sd"> mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)</span>
|
||||
|
||||
<span class="sd"> # Generate inputs for the GEMM</span>
|
||||
<span class="sd"> A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]</span>
|
||||
|
||||
<span class="sd"> # Run the module</span>
|
||||
<span class="sd"> D = mod.run(A, B, C)</span>
|
||||
|
||||
|
||||
<span class="sd">Example usage without JIT compilation:</span>
|
||||
|
||||
<span class="sd">.. highlight:: python</span>
|
||||
<span class="sd">.. code-block:: python</span>
|
||||
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> op = plan.construct()</span>
|
||||
<span class="sd"> cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')</span>
|
||||
|
||||
<span class="sd">After this call, the directory ``output`` contains ``setup.py``,</span>
|
||||
<span class="sd">``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from</span>
|
||||
<span class="sd">within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.</span>
|
||||
|
||||
<span class="sd">The module can later be used in Python via:</span>
|
||||
|
||||
<span class="sd">.. highlight:: python</span>
|
||||
<span class="sd">.. code-block:: python</span>
|
||||
|
||||
<span class="sd"> import torch</span>
|
||||
<span class="sd"> import cutlass_gemm</span>
|
||||
|
||||
<span class="sd"> # Generate inputs for the GEMM</span>
|
||||
<span class="sd"> A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]</span>
|
||||
|
||||
<span class="sd"> # Run the module</span>
|
||||
<span class="sd"> D = cutlass_gemm.run(A, B, C)</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">logging</span>
|
||||
<span class="kn">import</span> <span class="nn">os</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">cutlass</span> <span class="kn">import</span> <span class="n">CUTLASS_PATH</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">swizzle</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.gemm_operation</span> <span class="kn">import</span> <span class="n">GemmOperationGrouped</span><span class="p">,</span> <span class="n">GemmOperationUniversal</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="n">ApiVersion</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.utils.software</span> <span class="kn">import</span> <span class="n">CheckPackages</span><span class="p">,</span> <span class="n">SubstituteTemplate</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.emit</span> <span class="kn">import</span> <span class="n">common</span>
|
||||
|
||||
<span class="n">torch_available</span> <span class="o">=</span> <span class="n">CheckPackages</span><span class="p">()</span><span class="o">.</span><span class="n">check_torch</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="n">torch_available</span><span class="p">:</span>
|
||||
<span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
|
||||
<span class="n">_PYTORCH_CUDA_TEMPLATE</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">#include &lt;torch/extension.h&gt;</span>
|
||||
<span class="s2">#include &lt;ATen/ATen.h&gt;</span>
|
||||
|
||||
<span class="s2">#include "cutlass/cutlass.h"</span>
|
||||
<span class="s2">#include "cutlass/util/device_memory.h"</span>
|
||||
|
||||
<span class="s2">$</span><span class="si">{includes}</span>
|
||||
<span class="s2">$</span><span class="si">{declaration}</span>
|
||||
<span class="s2">$</span><span class="si">{impl}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">_PYTORCH_GEMM_CPP_TEMPLATE</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">#include &lt;torch/extension.h&gt;</span>
|
||||
<span class="s2">#include &lt;ATen/ATen.h&gt;</span>
|
||||
<span class="s2">#include &lt;pybind11/stl.h&gt;</span>
|
||||
|
||||
<span class="s2">// CUDA forward declarations</span>
|
||||
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">_kernel(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C=at::nullopt, float alpha=1.f, float beta=0.f);</span>
|
||||
|
||||
<span class="s2">// C++ interface</span>
|
||||
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C=at::nullopt, float alpha=1.f, float beta=0.f) {</span>
|
||||
<span class="s2"> return $</span><span class="si">{name}</span><span class="s2">_kernel(A, B, C, alpha, beta);</span>
|
||||
<span class="s2">}</span>
|
||||
|
||||
<span class="s2">PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {</span>
|
||||
<span class="s2"> m.def("run", py::overload_cast&lt;const at::Tensor&amp;, const at::Tensor&amp;, at::optional&lt;const at::Tensor&gt;, float, float&gt;(&amp;$</span><span class="si">{name}</span><span class="s2">), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">#include &lt;torch/extension.h&gt;</span>
|
||||
<span class="s2">#include &lt;ATen/ATen.h&gt;</span>
|
||||
<span class="s2">#include &lt;pybind11/stl.h&gt;</span>
|
||||
|
||||
<span class="s2">// CUDA forward declarations</span>
|
||||
<span class="s2">std::vector&lt;at::Tensor&gt; $</span><span class="si">{name}</span><span class="s2">_kernel(const std::vector&lt;at::Tensor&gt;&amp; A, const std::vector&lt;at::Tensor&gt;&amp; B, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt; C=at::nullopt, float alpha=1.f, float beta=0.f);</span>
|
||||
|
||||
<span class="s2">// C++ interface</span>
|
||||
<span class="s2">std::vector&lt;at::Tensor&gt; $</span><span class="si">{name}</span><span class="s2">(const std::vector&lt;at::Tensor&gt;&amp; A, const std::vector&lt;at::Tensor&gt;&amp; B, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt; C=at::nullopt, float alpha=1.f, float beta=0.f) {</span>
|
||||
<span class="s2"> return $</span><span class="si">{name}</span><span class="s2">_kernel(A, B, C, alpha, beta);</span>
|
||||
<span class="s2">}</span>
|
||||
|
||||
<span class="s2">PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {</span>
|
||||
<span class="s2"> m.def("run", py::overload_cast&lt;const std::vector&lt;at::Tensor&gt;&amp;, const std::vector&lt;at::Tensor&gt;&amp;, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt;, float, float&gt;(&amp;$</span><span class="si">{name}</span><span class="s2">),</span>
|
||||
<span class="s2"> py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">_PYTORCH_GEMM_INCLUDES</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">ApiVersion</span><span class="o">.</span><span class="n">v2x</span><span class="p">:</span> <span class="s2">"""</span>
|
||||
<span class="s2">#include "cutlass/gemm/device/gemm_universal.h"</span>
|
||||
<span class="s2">"""</span><span class="p">,</span>
|
||||
<span class="n">ApiVersion</span><span class="o">.</span><span class="n">v3x</span><span class="p">:</span> <span class="s2">"""</span>
|
||||
<span class="s2">#include "cutlass/gemm/device/gemm_universal_adapter.h"</span>
|
||||
<span class="s2">#include "cutlass/gemm/collective/collective_builder.hpp"</span>
|
||||
<span class="s2">#include "cutlass/gemm/device/gemm_universal_adapter.h"</span>
|
||||
<span class="s2">#include "cutlass/gemm/kernel/gemm_universal.hpp"</span>
|
||||
<span class="s2">#include "cutlass/epilogue/collective/default_epilogue.hpp"</span>
|
||||
<span class="s2">#include "cutlass/util/packed_stride.hpp"</span>
|
||||
<span class="s2">"""</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
<span class="n">_PYTORCH_GROUPED_GEMM_INCLUDES</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">#include "cutlass/gemm/kernel/default_gemm_grouped.h"</span>
|
||||
<span class="s2">#include "cutlass/gemm/device/gemm_grouped.h"</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">_CUTLASS_TYPE_TO_TORCH_TYPE</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="s2">"torch::kF16"</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="s2">"torch::kF32"</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float64</span><span class="p">:</span> <span class="s2">"torch::kF64"</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int8</span><span class="p">:</span> <span class="s2">"torch::I8"</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="s2">"torch::I32"</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
<span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_2x</span> <span class="o">=</span> <span class="p">(</span>
|
||||
<span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_RUN_GEMM_2x</span>
|
||||
<span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">_kernel(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C, float alpha, float beta) {</span>
|
||||
<span class="s2"> int M = A.size(0);</span>
|
||||
<span class="s2"> int N = B.size(1);</span>
|
||||
<span class="s2"> int K = A.size(1);</span>
|
||||
|
||||
<span class="s2"> typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?</span>
|
||||
<span class="s2"> nullptr :</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(C-&gt;contiguous().data_ptr());</span>
|
||||
<span class="s2"> at::Tensor D = B.new_empty({M, N}, $</span><span class="si">{torch_type_C}</span><span class="s2">);</span>
|
||||
|
||||
<span class="s2"> cutlass::Status status = $</span><span class="si">{name}</span><span class="s2">_kernel_run(M, N, K,</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementA*&gt;(A.contiguous().data_ptr()),</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementB*&gt;(B.contiguous().data_ptr()),</span>
|
||||
<span class="s2"> ptrC,</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(D.contiguous().data_ptr()),</span>
|
||||
<span class="s2"> ElementCompute(alpha), ElementCompute(beta));</span>
|
||||
|
||||
<span class="s2"> TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");</span>
|
||||
<span class="s2"> return D;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_3x</span> <span class="o">=</span> <span class="p">(</span>
|
||||
<span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_RUN_GEMM_3x</span>
|
||||
<span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">bool hw_info_queried = false;</span>
|
||||
<span class="s2">cutlass::KernelHardwareInfo hw_info;</span>
|
||||
|
||||
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">_kernel(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C, float alpha, float beta) {</span>
|
||||
<span class="s2"> int M = A.size(0);</span>
|
||||
<span class="s2"> int N = B.size(1);</span>
|
||||
<span class="s2"> int K = A.size(1);</span>
|
||||
<span class="s2"> int L = 1;</span>
|
||||
|
||||
<span class="s2"> // Query hardware info if we haven't already</span>
|
||||
<span class="s2"> if (!hw_info_queried) {</span>
|
||||
<span class="s2"> hw_info.device_id = 0;</span>
|
||||
<span class="s2"> hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);</span>
|
||||
<span class="s2"> }</span>
|
||||
|
||||
<span class="s2"> typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?</span>
|
||||
<span class="s2"> nullptr :</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(C-&gt;contiguous().data_ptr());</span>
|
||||
<span class="s2"> at::Tensor D = B.new_empty({M, N}, $</span><span class="si">{torch_type_C}</span><span class="s2">);</span>
|
||||
|
||||
<span class="s2"> cutlass::Status status = $</span><span class="si">{name}</span><span class="s2">_kernel_run(M, N, K, L,</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementA*&gt;(A.contiguous().data_ptr()),</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementB*&gt;(B.contiguous().data_ptr()),</span>
|
||||
<span class="s2"> ptrC,</span>
|
||||
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(D.contiguous().data_ptr()),</span>
|
||||
<span class="s2"> ElementCompute(alpha), ElementCompute(beta),</span>
|
||||
<span class="s2"> hw_info);</span>
|
||||
|
||||
<span class="s2"> TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");</span>
|
||||
<span class="s2"> return D;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
|
||||
<span class="n">_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE</span> <span class="o">=</span> <span class="p">(</span>
|
||||
<span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x</span>
|
||||
<span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">std::vector&lt;at::Tensor&gt; $</span><span class="si">{name}</span><span class="s2">_kernel(const std::vector&lt;at::Tensor&gt;&amp; A, const std::vector&lt;at::Tensor&gt;&amp; B, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt; C, float alpha, float beta) {</span>
|
||||
<span class="s2"> size_t num = A.size();</span>
|
||||
|
||||
<span class="s2"> // To avoid performing many small cudaMallocs and host-to-device copies,</span>
|
||||
<span class="s2"> // we serialize the grouped GEMM arguments on the host, allocate one</span>
|
||||
<span class="s2"> // large chunk of device memory, and perform a single cudaMemcpy to</span>
|
||||
<span class="s2"> // copy the host data to the device. Allocation overheads could be</span>
|
||||
<span class="s2"> // avoided by using a memory pool.</span>
|
||||
|
||||
<span class="s2"> // Calculate the total size of the data to be copied from host to device</span>
|
||||
<span class="s2"> size_t total_size = sizeof(cutlass::gemm::GemmCoord) +</span>
|
||||
<span class="s2"> sizeof(DeviceKernel::ElementA*) +</span>
|
||||
<span class="s2"> sizeof(DeviceKernel::ElementB*) +</span>
|
||||
<span class="s2"> sizeof(DeviceKernel::ElementC*) +</span>
|
||||
<span class="s2"> sizeof(DeviceKernel::ElementC*) +</span>
|
||||
<span class="s2"> sizeof(int64_t) +</span>
|
||||
<span class="s2"> sizeof(int64_t) +</span>
|
||||
<span class="s2"> sizeof(int64_t);</span>
|
||||
<span class="s2"> total_size *= num;</span>
|
||||
|
||||
<span class="s2"> // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple</span>
|
||||
<span class="s2"> // of sizeof(DeviceKernel::ElementA*) (which will be 8 on a 64-bit system).</span>
|
||||
<span class="s2"> // To ensure that we don't end up having misaligned loads in the kernel,</span>
|
||||
<span class="s2"> // we pad to the nearest multiple of 8.</span>
|
||||
<span class="s2"> //</span>
|
||||
<span class="s2"> // Note that, even on a 32-bit system (for which sizeof(X*) will not equal</span>
|
||||
<span class="s2"> // sizeof(int64_t)), only padding between the list of GemmCoords and the</span>
|
||||
<span class="s2"> // list of ptr_As is sufficient because the set of four equal-length lists of pointers</span>
|
||||
<span class="s2"> // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always</span>
|
||||
<span class="s2"> // start on a multiple of 8.</span>
|
||||
<span class="s2"> int64_t padding = 8 - (total_size % 8);</span>
|
||||
<span class="s2"> total_size += padding;</span>
|
||||
|
||||
<span class="s2"> uint8_t* host_data = new uint8_t[total_size];</span>
|
||||
<span class="s2"> cutlass::DeviceAllocation<uint8_t> device_data(total_size);</span>
|
||||
|
||||
<span class="s2"> uint8_t* start = host_data;</span>
|
||||
<span class="s2"> cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);</span>
|
||||
|
||||
<span class="s2"> // Apply the padding after the list of GemmCoords</span>
|
||||
<span class="s2"> start += num * sizeof(cutlass::gemm::GemmCoord) + padding;</span>
|
||||
|
||||
<span class="s2"> int64_t ptr_A_offset = start - host_data;</span>
|
||||
<span class="s2"> DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(DeviceKernel::ElementA*);</span>
|
||||
|
||||
<span class="s2"> int64_t ptr_B_offset = start - host_data;</span>
|
||||
<span class="s2"> DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(DeviceKernel::ElementB*);</span>
|
||||
|
||||
<span class="s2"> int64_t ptr_C_offset = start - host_data;</span>
|
||||
<span class="s2"> DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(DeviceKernel::ElementC*);</span>
|
||||
|
||||
<span class="s2"> int64_t ptr_D_offset = start - host_data;</span>
|
||||
<span class="s2"> DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(DeviceKernel::ElementC*);</span>
|
||||
|
||||
<span class="s2"> int64_t lda_offset = start - host_data;</span>
|
||||
<span class="s2"> int64_t* lda_host = reinterpret_cast<int64_t*>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(int64_t);</span>
|
||||
|
||||
<span class="s2"> int64_t ldb_offset = start - host_data;</span>
|
||||
<span class="s2"> int64_t* ldb_host = reinterpret_cast<int64_t*>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(int64_t);</span>
|
||||
|
||||
<span class="s2"> int64_t ldc_offset = start - host_data;</span>
|
||||
<span class="s2"> int64_t* ldc_host = reinterpret_cast<int64_t*>(start);</span>
|
||||
<span class="s2"> start += num * sizeof(int64_t);</span>
|
||||
|
||||
<span class="s2"> std::vector<at::Tensor> D(num);</span>
|
||||
|
||||
<span class="s2"> bool need_C = (C != at::nullopt) && (beta != 0.f);</span>
|
||||
<span class="s2"> for (size_t i = 0; i < num; ++i) {</span>
|
||||
<span class="s2"> int M = A[i].size(0);</span>
|
||||
<span class="s2"> int N = B[i].size(1);</span>
|
||||
<span class="s2"> int K = A[i].size(1);</span>
|
||||
<span class="s2"> *(problem_sizes_host + i) = {M, N, K};</span>
|
||||
<span class="s2"> *(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());</span>
|
||||
<span class="s2"> *(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());</span>
|
||||
|
||||
<span class="s2"> if (need_C) {</span>
|
||||
<span class="s2"> *(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> else {</span>
|
||||
<span class="s2"> *(ptr_C_host + i) = nullptr;</span>
|
||||
<span class="s2"> }</span>
|
||||
|
||||
<span class="s2"> D[i] = B[i].new_empty({M, N}, $</span><span class="si">{torch_type_C}</span><span class="s2">);</span>
|
||||
<span class="s2"> *(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());</span>
|
||||
|
||||
<span class="s2"> *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);</span>
|
||||
<span class="s2"> *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);</span>
|
||||
<span class="s2"> *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);</span>
|
||||
<span class="s2"> }</span>
|
||||
|
||||
<span class="s2"> device_data.copy_from_host(host_data);</span>
|
||||
|
||||
<span class="s2"> cutlass::Status status = $</span><span class="si">{name}</span><span class="s2">_kernel_run(</span>
|
||||
<span class="s2"> num,</span>
|
||||
<span class="s2"> reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),</span>
|
||||
<span class="s2"> reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + lda_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),</span>
|
||||
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),</span>
|
||||
<span class="s2"> ElementCompute(alpha), ElementCompute(beta));</span>
|
||||
|
||||
<span class="s2"> delete[] host_data;</span>
|
||||
|
||||
<span class="s2"> TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");</span>
|
||||
<span class="s2"> return D;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
|
||||
<span class="n">_PYTORCH_SETUP_PY</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_PYSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
||||
<span class="s2">from setuptools import setup</span>
|
||||
<span class="s2">from torch.utils.cpp_extension import BuildExtension, CUDAExtension</span>
|
||||
|
||||
<span class="s2">setup(</span>
|
||||
<span class="s2"> name='$</span><span class="si">{name}</span><span class="s2">',</span>
|
||||
<span class="s2"> ext_modules=[</span>
|
||||
<span class="s2"> CUDAExtension('$</span><span class="si">{name}</span><span class="s2">', [</span>
|
||||
<span class="s2"> '$</span><span class="si">{name}</span><span class="s2">.cpp',</span>
|
||||
<span class="s2"> '$</span><span class="si">{name}</span><span class="s2">_kernel.cu',</span>
|
||||
<span class="s2"> ],</span>
|
||||
<span class="s2"> include_dirs=['$</span><span class="si">{cutlass_path}</span><span class="s2">/include', '$</span><span class="si">{cutlass_path}</span><span class="s2">/tools/util/include'],</span>
|
||||
<span class="s2"> extra_compile_args=['-std=c++17']</span>
|
||||
<span class="s2"> ),</span>
|
||||
<span class="s2"> ],</span>
|
||||
<span class="s2"> cmdclass={</span>
|
||||
<span class="s2"> 'build_ext': BuildExtension</span>
|
||||
<span class="s2"> })</span>
|
||||
|
||||
<span class="s2">"""</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">_generate_setup</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Generates a setup.py file for the extension</span>
|
||||
|
||||
<span class="sd"> :param name: name of the module to generate</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
||||
<span class="sd"> :type sourcedir: str</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">setup_py_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="s2">"setup.py"</span><span class="p">)</span>
|
||||
<span class="n">setup_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
||||
<span class="n">_PYTORCH_SETUP_PY</span><span class="p">,</span> <span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="s2">"cutlass_path"</span><span class="p">:</span> <span class="n">CUTLASS_PATH</span><span class="p">}</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">setup_py_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
||||
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">setup_source</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">_ArchListSetter</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``</span>
|
||||
<span class="sd"> environment variable when building a PyTorch CUDA module.</span>
|
||||
|
||||
<span class="sd"> ``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch</span>
|
||||
<span class="sd"> CUDA module should be compiled.</span>
|
||||
|
||||
<span class="sd"> For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of</span>
|
||||
<span class="sd"> ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the</span>
|
||||
<span class="sd"> compilation of the module.</span>
|
||||
|
||||
<span class="sd"> This utility wraps the building of a PyTorch CUDA module with a setting of this environment</span>
|
||||
<span class="sd"> variable according to the current compute capability being targeted.</span>
|
||||
|
||||
<span class="sd"> Example usage:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"</span>
|
||||
<span class="sd"> with _ArchListSetter(80):</span>
|
||||
<span class="sd"> # Perform JIT compilation and loading of the module</span>
|
||||
<span class="sd"> mod = torch.utils.cpp_extension.load(...)</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="n">_TORCH_CUDA_ARCH_LIST</span> <span class="o">=</span> <span class="s2">"TORCH_CUDA_ARCH_LIST"</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">cc_str</span> <span class="o">=</span> <span class="s2">"."</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">cc</span><span class="p">)))</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__enter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Saves the old value of TORCH_CUDA_ARCH_LIST and resets it to the new value based on ``cc``</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">old_arch_list</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="n">_ArchListSetter</span><span class="o">.</span><span class="n">_TORCH_CUDA_ARCH_LIST</span><span class="p">)</span>
|
||||
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="n">_ArchListSetter</span><span class="o">.</span><span class="n">_TORCH_CUDA_ARCH_LIST</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cc_str</span>
|
||||
|
||||
<span class="k">return</span> <span class="bp">self</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__exit__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_val</span><span class="p">,</span> <span class="n">traceback</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Restores the old value of TORCH_CUDA_ARCH_LIST</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="n">_ArchListSetter</span><span class="o">.</span><span class="n">_TORCH_CUDA_ARCH_LIST</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">old_arch_list</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">_jit</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">cpp_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> JIT compiles and loads a PyTorch CUDA extension.</span>
|
||||
|
||||
<span class="sd"> :param name: name of the module to generate</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param cpp_file: path to file containing extension's C++ interface</span>
|
||||
<span class="sd"> :type cpp_file: str</span>
|
||||
<span class="sd"> :param cuda_file: path to file containing extension's CUDA interface</span>
|
||||
<span class="sd"> :type cuda_file: str</span>
|
||||
|
||||
<span class="sd"> :return: loaded PyTorch module</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">torch.utils.cpp_extension</span> <span class="kn">import</span> <span class="n">load</span>
|
||||
|
||||
<span class="n">extra_cuda_cflags</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"-std=c++17"</span><span class="p">]</span>
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="o">==</span> <span class="mi">90</span><span class="p">:</span>
|
||||
<span class="c1"># PyTorch does not currently add the sm_90a target when compute capability</span>
|
||||
<span class="c1"># 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.</span>
|
||||
<span class="n">extra_cuda_cflags</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">"-gencode=arch=compute_90a,code=sm_90a"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">with</span> <span class="n">_ArchListSetter</span><span class="p">(</span><span class="n">cc</span><span class="p">):</span>
|
||||
<span class="n">jitmodule</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="p">,</span>
|
||||
<span class="p">[</span><span class="n">cpp_file</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">],</span>
|
||||
<span class="n">extra_cuda_cflags</span><span class="o">=</span><span class="n">extra_cuda_cflags</span><span class="p">,</span>
|
||||
<span class="n">extra_include_paths</span><span class="o">=</span><span class="p">[</span>
|
||||
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">CUTLASS_PATH</span><span class="p">,</span> <span class="s2">"include"</span><span class="p">),</span>
|
||||
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">CUTLASS_PATH</span><span class="p">,</span> <span class="s2">"tools/util/include"</span><span class="p">),</span>
|
||||
<span class="p">],</span>
|
||||
<span class="n">verbose</span><span class="o">=</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">level</span> <span class="o">==</span> <span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span><span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">jitmodule</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">_pytorch_gemm</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">jit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM</span>
|
||||
<span class="sd"> specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time</span>
|
||||
<span class="sd"> compiled, loaded, and returned.</span>
|
||||
|
||||
<span class="sd"> :param op: operation to emit in the module</span>
|
||||
<span class="sd"> :param name: name of the module to generate</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param jit: whether the module should be just-in-time compiled</span>
|
||||
<span class="sd"> :type jit: bool</span>
|
||||
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
||||
<span class="sd"> :type sourcedir: str</span>
|
||||
|
||||
<span class="sd"> :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">sourcedir</span> <span class="o">!=</span> <span class="s2">""</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">):</span>
|
||||
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">)</span>
|
||||
|
||||
<span class="n">cuda_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">"_kernel.cu"</span><span class="p">)</span>
|
||||
<span class="n">extra_kw</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">api</span> <span class="o">==</span> <span class="n">ApiVersion</span><span class="o">.</span><span class="n">v3x</span><span class="p">:</span>
|
||||
<span class="n">impl_template</span> <span class="o">=</span> <span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_3x</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">impl_template</span> <span class="o">=</span> <span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_2x</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="o">.</span><span class="n">swizzling_functor</span><span class="p">,</span> <span class="n">swizzle</span><span class="o">.</span><span class="n">ThreadblockSwizzleStreamK</span><span class="p">):</span>
|
||||
<span class="n">extra_kw</span><span class="p">[</span><span class="s2">"args"</span><span class="p">]</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_ARGS_2x_STREAM_K</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">extra_kw</span><span class="p">[</span><span class="s2">"args"</span><span class="p">]</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_ARGS_2x</span>
|
||||
<span class="n">impl_template</span> <span class="o">=</span> <span class="p">(</span>
|
||||
<span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_3x</span>
|
||||
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">api</span> <span class="o">==</span> <span class="n">ApiVersion</span><span class="o">.</span><span class="n">v3x</span>
|
||||
<span class="k">else</span> <span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_2x</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">cuda_impl</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span><span class="n">impl_template</span><span class="p">,</span> <span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">extra_kw</span><span class="p">})</span>
|
||||
<span class="n">cuda_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
||||
<span class="n">_PYTORCH_CUDA_TEMPLATE</span><span class="p">,</span>
|
||||
<span class="p">{</span>
|
||||
<span class="s2">"includes"</span><span class="p">:</span> <span class="n">_PYTORCH_GEMM_INCLUDES</span><span class="p">[</span><span class="n">op</span><span class="o">.</span><span class="n">api</span><span class="p">],</span>
|
||||
<span class="s2">"declaration"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">rt_module</span><span class="o">.</span><span class="n">emit</span><span class="p">(),</span>
|
||||
<span class="s2">"procedural_name"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">(),</span>
|
||||
<span class="s2">"impl"</span><span class="p">:</span> <span class="n">cuda_impl</span><span class="p">,</span>
|
||||
<span class="s2">"torch_type_C"</span><span class="p">:</span> <span class="n">_CUTLASS_TYPE_TO_TORCH_TYPE</span><span class="p">[</span><span class="n">op</span><span class="o">.</span><span class="n">C</span><span class="o">.</span><span class="n">element</span><span class="p">],</span>
|
||||
<span class="p">},</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cuda_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
||||
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cuda_source</span><span class="p">)</span>
|
||||
|
||||
<span class="n">cpp_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">".cpp"</span><span class="p">)</span>
|
||||
<span class="n">cpp_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
||||
<span class="n">_PYTORCH_GEMM_CPP_TEMPLATE</span><span class="p">,</span>
|
||||
<span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="s2">"description"</span><span class="p">:</span> <span class="sa">f</span><span class="s2">"CUTLASS </span><span class="si">{</span><span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">()</span><span class="si">}</span><span class="s2"> GEMM"</span><span class="p">},</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cpp_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
||||
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cpp_source</span><span class="p">)</span>
|
||||
|
||||
<span class="n">_generate_setup</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">jit</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">_jit</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">cpp_file</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="kc">None</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">_pytorch_grouped_gemm</span><span class="p">(</span>
|
||||
<span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">jit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
|
||||
<span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM</span>
|
||||
<span class="sd"> specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time</span>
|
||||
<span class="sd"> compiled, loaded, and returned.</span>
|
||||
|
||||
<span class="sd"> :param op: operation to emit in the module</span>
|
||||
<span class="sd"> :param name: name of the module to generate</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param jit: whether the module should be just-in-time compiled</span>
|
||||
<span class="sd"> :type jit: bool</span>
|
||||
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
||||
<span class="sd"> :type sourcedir: str</span>
|
||||
|
||||
<span class="sd"> :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">api</span> <span class="o">!=</span> <span class="n">ApiVersion</span><span class="o">.</span><span class="n">v2x</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"Grouped GEMM is currently only supported for CUTLASS 2.x"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">sourcedir</span> <span class="o">!=</span> <span class="s2">""</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">):</span>
|
||||
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">)</span>
|
||||
|
||||
<span class="n">cuda_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">"_kernel.cu"</span><span class="p">)</span>
|
||||
<span class="n">cuda_impl</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span><span class="n">_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE</span><span class="p">,</span> <span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">})</span>
|
||||
<span class="n">cuda_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
||||
<span class="n">_PYTORCH_CUDA_TEMPLATE</span><span class="p">,</span>
|
||||
<span class="p">{</span>
|
||||
<span class="s2">"includes"</span><span class="p">:</span> <span class="n">_PYTORCH_GROUPED_GEMM_INCLUDES</span><span class="p">,</span>
|
||||
<span class="s2">"declaration"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">rt_module</span><span class="o">.</span><span class="n">emit</span><span class="p">(),</span>
|
||||
<span class="s2">"procedural_name"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">(),</span>
|
||||
<span class="s2">"impl"</span><span class="p">:</span> <span class="n">cuda_impl</span><span class="p">,</span>
|
||||
<span class="s2">"torch_type_C"</span><span class="p">:</span> <span class="n">_CUTLASS_TYPE_TO_TORCH_TYPE</span><span class="p">[</span><span class="n">op</span><span class="o">.</span><span class="n">C</span><span class="o">.</span><span class="n">element</span><span class="p">],</span>
|
||||
<span class="p">},</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cuda_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
||||
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cuda_source</span><span class="p">)</span>
|
||||
|
||||
<span class="n">cpp_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">".cpp"</span><span class="p">)</span>
|
||||
<span class="n">cpp_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
||||
<span class="n">_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE</span><span class="p">,</span>
|
||||
<span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="s2">"description"</span><span class="p">:</span> <span class="sa">f</span><span class="s2">"CUTLASS </span><span class="si">{</span><span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">()</span><span class="si">}</span><span class="s2"> grouped GEMM"</span><span class="p">},</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cpp_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
||||
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cpp_source</span><span class="p">)</span>
|
||||
|
||||
<span class="n">_generate_setup</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">jit</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">_jit</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">cpp_file</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="kc">None</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="pytorch"><a class="viewcode-back" href="../../../cutlass.emit.html#cutlass.emit.pytorch.pytorch">[docs]</a><span class="k">def</span> <span class="nf">pytorch</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">jit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel</span>
|
||||
<span class="sd"> specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time</span>
|
||||
<span class="sd"> compiled, loaded, and returned.</span>
|
||||
|
||||
<span class="sd"> The result of this method is files within ``sourcedir`` that can be used for building</span>
|
||||
<span class="sd"> a PyTorch module.</span>
|
||||
|
||||
<span class="sd"> :param op: operation to emit in the module</span>
|
||||
<span class="sd"> :param name: name of the module to generate</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param jit: whether the module should be just-in-time compiled</span>
|
||||
<span class="sd"> :type jit: bool</span>
|
||||
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
||||
<span class="sd"> :type sourcedir: str</span>
|
||||
|
||||
<span class="sd"> :return: loaded PyTorch module (if ``jit=True``) or None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">device_op</span> <span class="o">=</span> <span class="n">op</span><span class="o">.</span><span class="n">device_op</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">GemmOperationUniversal</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_pytorch_gemm</span><span class="p">(</span><span class="n">device_op</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">GemmOperationGrouped</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_pytorch_grouped_gemm</span><span class="p">(</span><span class="n">device_op</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s2">"Operation type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">op</span><span class="p">)</span><span class="si">}</span><span class="s2"> is not currently supported for PyTorch emission."</span>
|
||||
<span class="p">)</span></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
||||
<script src="../../../_static/doctools.js"></script>
|
||||
<script src="../../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../../_static/scripts/furo.js"></script>
|
||||
<script src="../../../_static/clipboard.min.js"></script>
|
||||
<script src="../../../_static/copybutton.js"></script>
|
||||
<script src="../../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
391
python/docs/_modules/cutlass/epilogue.html
Normal file
391
python/docs/_modules/cutlass/epilogue.html
Normal file
@ -0,0 +1,391 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../genindex.html" /><link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/epilogue.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.epilogue - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../_static/cutlass-logo-small.png" alt=""/>
|
||||
<img class="sidebar-logo only-dark" src="../../_static/cutlass-logo-small.png" alt=""/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.epilogue</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Registry of elementwise epilogues</span>
|
||||
|
||||
<span class="sd">Elementwise epilogues can be added to many CUTLASS kernels in the CUTLASS Python interface via</span>
|
||||
<span class="sd">code like the following for GEMM:</span>
|
||||
|
||||
<span class="sd">.. highlight:: python</span>
|
||||
<span class="sd">.. code-block:: python</span>
|
||||
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> plan.activation = cutlass.epilogue.relu</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend</span> <span class="kn">import</span> <span class="n">epilogue</span>
|
||||
|
||||
<span class="n">gelu</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">gelu</span>
|
||||
<span class="n">hardswish</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">hardswish</span>
|
||||
<span class="n">identity</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">identity</span>
|
||||
<span class="n">leaky_relu</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">leaky_relu</span>
|
||||
<span class="n">relu</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">relu</span>
|
||||
<span class="n">sigmoid</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">sigmoid</span>
|
||||
<span class="n">silu</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">silu</span>
|
||||
<span class="n">tanh</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">tanh</span>
|
||||
|
||||
|
||||
<span class="n">_activations</span> <span class="o">=</span> <span class="p">[</span><span class="n">gelu</span><span class="p">,</span> <span class="n">hardswish</span><span class="p">,</span> <span class="n">identity</span><span class="p">,</span> <span class="n">leaky_relu</span><span class="p">,</span> <span class="n">relu</span><span class="p">,</span> <span class="n">sigmoid</span><span class="p">,</span> <span class="n">silu</span><span class="p">,</span> <span class="n">tanh</span><span class="p">]</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="get_activations"><a class="viewcode-back" href="../../cutlass.html#cutlass.epilogue.get_activations">[docs]</a><span class="k">def</span> <span class="nf">get_activations</span><span class="p">()</span> <span class="o">-></span> <span class="nb">list</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns a list of available activation functions</span>
|
||||
|
||||
<span class="sd"> :return: list of available activation functions</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="n">_activations</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="get_activation_epilogue"><a class="viewcode-back" href="../../cutlass.html#cutlass.epilogue.get_activation_epilogue">[docs]</a><span class="k">def</span> <span class="nf">get_activation_epilogue</span><span class="p">(</span>
|
||||
<span class="n">activation</span><span class="p">,</span>
|
||||
<span class="n">element_output</span><span class="p">,</span>
|
||||
<span class="n">elements_per_access</span><span class="p">,</span>
|
||||
<span class="n">element_accumulator</span><span class="p">,</span>
|
||||
<span class="n">element_compute</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Return an epilogue corresponding to the activation function, data types, and alignment</span>
|
||||
<span class="sd"> used in the kernel</span>
|
||||
|
||||
<span class="sd"> :param activation: elementwise activation function to use</span>
|
||||
<span class="sd"> :param element_output: data type of the output</span>
|
||||
<span class="sd"> :param elements_per_access: alignment of operand C of the kernel</span>
|
||||
<span class="sd"> :type elements_per_access: int</span>
|
||||
<span class="sd"> :param element_accumulator: data type of the accumulated output C</span>
|
||||
<span class="sd"> :param element_compute: data type in which compute operations should be performed</span>
|
||||
|
||||
<span class="sd"> :return: epilogue functor</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">activation</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_activations</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s2">"Unsupported activation type </span><span class="si">{</span><span class="n">activation</span><span class="si">}</span><span class="s2">. Available activations are: </span><span class="si">{</span><span class="n">_activations</span><span class="si">}</span><span class="s2">"</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">activation</span> <span class="o">==</span> <span class="n">identity</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">LinearCombination</span><span class="p">(</span>
|
||||
<span class="n">element_output</span><span class="p">,</span> <span class="n">elements_per_access</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="p">,</span> <span class="n">element_compute</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">LinearCombinationGeneric</span><span class="p">(</span>
|
||||
<span class="n">activation</span><span class="p">(</span><span class="n">element_compute</span><span class="p">),</span>
|
||||
<span class="n">element_output</span><span class="p">,</span>
|
||||
<span class="n">elements_per_access</span><span class="p">,</span>
|
||||
<span class="n">element_accumulator</span><span class="p">,</span>
|
||||
<span class="n">element_compute</span><span class="p">,</span>
|
||||
<span class="p">)</span></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script src="../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../_static/scripts/furo.js"></script>
|
||||
<script src="../../_static/clipboard.min.js"></script>
|
||||
<script src="../../_static/copybutton.js"></script>
|
||||
<script src="../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
729
python/docs/_modules/cutlass/library_defaults.html
Normal file
729
python/docs/_modules/cutlass/library_defaults.html
Normal file
@ -0,0 +1,729 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../genindex.html" /><link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/library_defaults.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.library_defaults - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../_static/cutlass-logo-small.png" alt="CUTLASS logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../_static/cutlass-logo-small.png" alt="CUTLASS logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.library_defaults</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Classes containing valid operations for a given compute capability and data types.</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">logging</span>
|
||||
<span class="kn">from</span> <span class="nn">cuda</span> <span class="kn">import</span> <span class="n">__version__</span>
|
||||
|
||||
<span class="c1"># Strip any additional information from the CUDA version</span>
|
||||
<span class="n">_cuda_version</span> <span class="o">=</span> <span class="n">__version__</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"rc"</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># Imports from CUTLASS profiler generator and manifest scripts</span>
|
||||
<span class="kn">import</span> <span class="nn">generator</span> <span class="k">as</span> <span class="nn">prof_generator</span>
|
||||
<span class="kn">import</span> <span class="nn">manifest</span> <span class="k">as</span> <span class="nn">prof_manifest</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.utils.check</span> <span class="kn">import</span> <span class="n">valid_stage_count</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.utils.datatypes</span> <span class="kn">import</span> <span class="n">td_from_profiler_td</span><span class="p">,</span> <span class="n">td_from_profiler_op</span><span class="p">,</span> <span class="n">has_binding_type</span>
|
||||
|
||||
|
||||
<span class="n">_generator_ccs</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50</span><span class="p">,</span> <span class="mi">60</span><span class="p">,</span> <span class="mi">61</span><span class="p">,</span> <span class="mi">70</span><span class="p">,</span> <span class="mi">75</span><span class="p">,</span> <span class="mi">80</span><span class="p">,</span> <span class="mi">90</span><span class="p">]</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="KernelsForDataType"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.KernelsForDataType">[docs]</a><span class="k">class</span> <span class="nc">KernelsForDataType</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Container class for keeping track of kernels that correspond to a particular combination</span>
|
||||
<span class="sd"> of data types for operands A, B, and accumulator</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">datatype_comb</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">datatype_comb</span> <span class="o">=</span> <span class="n">datatype_comb</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">layout_comb</span> <span class="o">=</span> <span class="n">layout_comb</span>
|
||||
|
||||
<span class="c1"># Dictionary mapping from alignment (int) to a list of kernels that fit the alignment</span>
|
||||
<span class="c1"># constraint for the data type combination</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
<div class="viewcode-block" id="KernelsForDataType.add"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.KernelsForDataType.add">[docs]</a> <span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">operation</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Add an operation to the list of supported kernels</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">alignment</span> <span class="o">=</span> <span class="n">operation</span><span class="o">.</span><span class="n">A</span><span class="o">.</span><span class="n">alignment</span>
|
||||
<span class="k">if</span> <span class="n">alignment</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="p">[</span><span class="n">alignment</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="p">[</span><span class="n">alignment</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">operation</span><span class="p">)</span></div>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">alignments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns an unsorted list of alignments supported by this data type combination</span>
|
||||
|
||||
<span class="sd"> :return: unsorted list of alignments supported by this data type combination</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">all_operations</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns a list of all operations supported by this data type combination</span>
|
||||
|
||||
<span class="sd"> :return: list of all operations supported by this data type combination</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">ops</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">alignment_ops</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||||
<span class="n">ops</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">alignment_ops</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">ops</span>
|
||||
|
||||
<div class="viewcode-block" id="KernelsForDataType.operations"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.KernelsForDataType.operations">[docs]</a> <span class="k">def</span> <span class="nf">operations</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alignment</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns operations satisfying the alignment constraint indicated by `alignment`</span>
|
||||
|
||||
<span class="sd"> :param alignment: alignment constraint of operations to return</span>
|
||||
<span class="sd"> :type alignment: int</span>
|
||||
|
||||
<span class="sd"> :return: list of operations</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">alignment</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s2">"No operations of alignment </span><span class="si">{</span><span class="n">alignment</span><span class="si">}</span><span class="s2"> found for data type and layout "</span>
|
||||
<span class="sa">f</span><span class="s2">"combination </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">datatype_comb</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">layout_comb</span><span class="si">}</span><span class="s2">"</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="p">[</span><span class="n">alignment</span><span class="p">]</span></div>
|
||||
|
||||
<div class="viewcode-block" id="KernelsForDataType.find_alignment"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.KernelsForDataType.find_alignment">[docs]</a> <span class="k">def</span> <span class="nf">find_alignment</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">layout</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the most preferable alignment for a given shape and layout</span>
|
||||
|
||||
<span class="sd"> :param shape: extent of each dimension of the tensor</span>
|
||||
<span class="sd"> :type shape: tuple</span>
|
||||
<span class="sd"> :param layout: layout of the tensor</span>
|
||||
<span class="sd"> :type layout: cutlass.LayoutType</span>
|
||||
|
||||
<span class="sd"> :return: maximum alignment supported by the data type combination and tensor size</span>
|
||||
<span class="sd"> :rtype: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="c1"># Determine the leading dimension of the shape</span>
|
||||
<span class="k">if</span> <span class="n">layout</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">:</span>
|
||||
<span class="n">ld</span> <span class="o">=</span> <span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="k">elif</span> <span class="n">layout</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">:</span>
|
||||
<span class="n">ld</span> <span class="o">=</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Unexpected or unsupported layout </span><span class="si">{</span><span class="n">layout</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">alignment</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="o">.</span><span class="n">keys</span><span class="p">()),</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">ld</span> <span class="o">%</span> <span class="n">alignment</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">alignment</span>
|
||||
|
||||
<span class="c1"># Default to alignment of 1 if no others match</span>
|
||||
<span class="k">return</span> <span class="mi">1</span></div>
|
||||
|
||||
<div class="viewcode-block" id="KernelsForDataType.sort"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.KernelsForDataType.sort">[docs]</a> <span class="k">def</span> <span class="nf">sort</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">op</span><span class="p">:</span> <span class="p">(</span>
|
||||
<span class="n">op</span><span class="o">.</span><span class="n">tile_description</span><span class="o">.</span><span class="n">threadblock_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="o">*</span> <span class="n">op</span><span class="o">.</span><span class="n">tile_description</span><span class="o">.</span><span class="n">threadblock_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="o">*</span> <span class="n">op</span><span class="o">.</span><span class="n">tile_description</span><span class="o">.</span><span class="n">threadblock_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">alignment</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kernels_by_alignment</span><span class="p">[</span><span class="n">alignment</span><span class="p">]</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">key</span><span class="p">,</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="ArchOptions"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.ArchOptions">[docs]</a><span class="k">class</span> <span class="nc">ArchOptions</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Structure for keeping track of kernels available on a given compute capability</span>
|
||||
|
||||
<span class="sd"> :param target_cc: compute capability of the device on which kernels will be run</span>
|
||||
<span class="sd"> :type target_cc: int</span>
|
||||
<span class="sd"> :param kernel_cc: compute capability of the kernels to generate</span>
|
||||
<span class="sd"> :type kernel_cc: int</span>
|
||||
<span class="sd"> :param operation_kind: type of operation to register</span>
|
||||
<span class="sd"> :type operation_kind: cutlass.OperationKind</span>
|
||||
<span class="sd"> :param gemm_kinds: types of GEMM operations that can be included</span>
|
||||
<span class="sd"> :type gemm_kinds: list</span>
|
||||
<span class="sd"> :param allowed_math_operations: types of primitive math operations allowed</span>
|
||||
<span class="sd"> :type allowed_math_operations: list</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span>
|
||||
<span class="n">target_cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||
<span class="n">kernel_cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||
<span class="n">operation_kind</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OperationKind</span><span class="p">,</span>
|
||||
<span class="n">gemm_kinds</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
|
||||
<span class="n">allowed_math_operations</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">MathOperation</span><span class="o">.</span><span class="n">multiply_add</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">MathOperation</span><span class="o">.</span><span class="n">multiply_add_saturate</span><span class="p">,</span>
|
||||
<span class="p">]</span>
|
||||
<span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">cc</span> <span class="o">=</span> <span class="n">kernel_cc</span>
|
||||
|
||||
<span class="c1"># Dictionary with following structure:</span>
|
||||
<span class="c1"># Key: OpcodeClass</span>
|
||||
<span class="c1"># Value: Dictionary with the following structure:</span>
|
||||
<span class="c1"># Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType)),</span>
|
||||
<span class="c1"># representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))</span>
|
||||
<span class="c1"># Value: KernelsForDataType</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">allowed_math_operations</span> <span class="o">=</span> <span class="n">allowed_math_operations</span>
|
||||
|
||||
<span class="c1"># Identify the method within CUTLASS generator script that generates kernel</span>
|
||||
<span class="c1"># descriptions for the target CC</span>
|
||||
<span class="n">generate_function_name</span> <span class="o">=</span> <span class="s2">"GenerateSM"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">kernel_cc</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">prof_generator</span><span class="p">,</span> <span class="n">generate_function_name</span><span class="p">):</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No generator found for architecture </span><span class="si">{</span><span class="n">kernel_cc</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="k">return</span>
|
||||
<span class="n">generate_function</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">prof_generator</span><span class="p">,</span> <span class="n">generate_function_name</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Initialize a default manifest and populate it with valid kernel descriptions</span>
|
||||
<span class="c1"># for the target CC</span>
|
||||
<span class="n">args</span> <span class="o">=</span> <span class="p">[</span>
|
||||
<span class="s2">"--kernels=all"</span><span class="p">,</span>
|
||||
<span class="sa">f</span><span class="s2">"--log-level=</span><span class="si">{</span><span class="n">logging</span><span class="o">.</span><span class="n">getLevelName</span><span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">level</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
|
||||
<span class="p">]</span>
|
||||
<span class="n">manifest_args</span> <span class="o">=</span> <span class="n">prof_generator</span><span class="o">.</span><span class="n">define_parser</span><span class="p">()</span><span class="o">.</span><span class="n">parse_args</span><span class="p">(</span><span class="n">args</span><span class="p">)</span>
|
||||
<span class="n">manifest</span> <span class="o">=</span> <span class="n">prof_manifest</span><span class="o">.</span><span class="n">Manifest</span><span class="p">(</span><span class="n">manifest_args</span><span class="p">)</span>
|
||||
<span class="n">generate_function</span><span class="p">(</span><span class="n">manifest</span><span class="p">,</span> <span class="n">_cuda_version</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">operation_kind</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">manifest</span><span class="o">.</span><span class="n">operations</span><span class="p">:</span>
|
||||
<span class="c1"># No kernels generated for this architecture, this could be because the CUDA</span>
|
||||
<span class="c1"># toolkit is insufficient to support operations in this CC</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No operations of type </span><span class="si">{</span><span class="n">operation_kind</span><span class="si">}</span><span class="s2"> found for CC </span><span class="si">{</span><span class="n">kernel_cc</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="k">return</span>
|
||||
|
||||
<span class="c1"># Iterate through the available operations for this operation kind and</span>
|
||||
<span class="c1"># find available opclasses and data types</span>
|
||||
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">op_list</span> <span class="ow">in</span> <span class="n">manifest</span><span class="o">.</span><span class="n">operations</span><span class="p">[</span><span class="n">operation_kind</span><span class="p">]</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||||
<span class="k">for</span> <span class="n">op</span> <span class="ow">in</span> <span class="n">op_list</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">gemm_kind</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">gemm_kinds</span><span class="p">:</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="n">mi</span> <span class="o">=</span> <span class="n">op</span><span class="o">.</span><span class="n">tile_description</span><span class="o">.</span><span class="n">math_instruction</span>
|
||||
<span class="k">if</span> <span class="n">mi</span><span class="o">.</span><span class="n">math_operation</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">allowed_math_operations</span><span class="p">:</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="n">datatype_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">element_a</span><span class="p">,</span> <span class="n">mi</span><span class="o">.</span><span class="n">element_b</span><span class="p">,</span> <span class="n">mi</span><span class="o">.</span><span class="n">element_accumulator</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Skip any data types that do not currently have conversions via cutlass_bindings</span>
|
||||
<span class="k">if</span> <span class="kc">False</span> <span class="ow">in</span> <span class="p">[</span><span class="n">has_binding_type</span><span class="p">(</span><span class="n">elt</span><span class="p">)</span> <span class="k">for</span> <span class="n">elt</span> <span class="ow">in</span> <span class="n">datatype_comb</span><span class="p">]:</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="c1"># Prune operations that don't fit in shared memory</span>
|
||||
<span class="n">td</span> <span class="o">=</span> <span class="n">td_from_profiler_op</span><span class="p">(</span><span class="n">op</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid_stage_count</span><span class="p">(</span><span class="n">target_cc</span><span class="p">,</span> <span class="n">td</span><span class="p">)[</span><span class="mi">0</span><span class="p">]:</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">mi</span><span class="o">.</span><span class="n">opcode_class</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">mi</span><span class="o">.</span><span class="n">opcode_class</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
<span class="n">datatype_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">element_a</span><span class="p">,</span> <span class="n">mi</span><span class="o">.</span><span class="n">element_b</span><span class="p">,</span> <span class="n">mi</span><span class="o">.</span><span class="n">element_accumulator</span><span class="p">)</span>
|
||||
<span class="n">layout_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">op</span><span class="o">.</span><span class="n">A</span><span class="o">.</span><span class="n">layout</span><span class="p">,</span> <span class="n">op</span><span class="o">.</span><span class="n">B</span><span class="o">.</span><span class="n">layout</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations</span>
|
||||
<span class="k">if</span> <span class="n">datatype_comb</span> <span class="o">==</span> <span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">tf32</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">tf32</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">):</span>
|
||||
<span class="c1"># TF32 kernels only supported on SM80 and beyond</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cc</span> <span class="o"><</span> <span class="mi">80</span><span class="p">:</span>
|
||||
<span class="k">continue</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">cc</span> <span class="o">==</span> <span class="mi">90</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="p">(</span><span class="n">op</span><span class="o">.</span><span class="n">A</span><span class="o">.</span><span class="n">element</span> <span class="o">!=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span>
|
||||
<span class="ow">or</span> <span class="n">op</span><span class="o">.</span><span class="n">B</span><span class="o">.</span><span class="n">element</span> <span class="o">!=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span>
|
||||
<span class="ow">or</span> <span class="n">op</span><span class="o">.</span><span class="n">C</span><span class="o">.</span><span class="n">element</span> <span class="o">!=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">):</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="n">datatype_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">)</span>
|
||||
|
||||
<span class="n">opclass_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">mi</span><span class="o">.</span><span class="n">opcode_class</span><span class="p">]</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">datatype_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">opclass_dict</span><span class="p">:</span>
|
||||
<span class="n">opclass_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">KernelsForDataType</span><span class="p">(</span><span class="n">datatype_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">)</span>
|
||||
<span class="n">opclass_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">op</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Set the default opclass to TensorOp, if available. Otherwise default to SIMT</span>
|
||||
<span class="k">if</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">TensorOp</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">TensorOp</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span>
|
||||
|
||||
<span class="c1"># The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.</span>
|
||||
<span class="c1"># Here, we generate additional versions via a generic TileDescription.</span>
|
||||
<span class="k">if</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
<span class="n">types</span> <span class="o">=</span> <span class="p">[</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s32</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">),</span>
|
||||
<span class="p">]</span>
|
||||
|
||||
<span class="n">layouts</span> <span class="o">=</span> <span class="p">[</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">),</span>
|
||||
<span class="p">]</span>
|
||||
<span class="n">alignment</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="n">epilogue_functor</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">EpilogueFunctor</span><span class="o">.</span><span class="n">LinearCombination</span>
|
||||
<span class="n">swizzling_functor</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">SwizzlingFunctor</span><span class="o">.</span><span class="n">Identity8</span>
|
||||
<span class="k">for</span> <span class="n">type_comb</span> <span class="ow">in</span> <span class="n">types</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">layout_comb</span> <span class="ow">in</span> <span class="n">layouts</span><span class="p">:</span>
|
||||
<span class="n">comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">type_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">comb</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">]:</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="n">A</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">TensorDescription</span><span class="p">(</span><span class="n">type_comb</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">layout_comb</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">alignment</span><span class="p">)</span>
|
||||
<span class="n">B</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">TensorDescription</span><span class="p">(</span><span class="n">type_comb</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">layout_comb</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">alignment</span><span class="p">)</span>
|
||||
<span class="n">C</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">TensorDescription</span><span class="p">(</span><span class="n">type_comb</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">,</span> <span class="n">alignment</span><span class="p">)</span>
|
||||
<span class="n">math_inst</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">MathInstruction</span><span class="p">(</span>
|
||||
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
|
||||
<span class="n">type_comb</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||||
<span class="n">type_comb</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
|
||||
<span class="n">type_comb</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">MathOperation</span><span class="o">.</span><span class="n">multiply_add</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="n">td</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">TileDescription</span><span class="p">(</span>
|
||||
<span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">8</span><span class="p">],</span> <span class="mi">2</span><span class="p">,</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">math_inst</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">1024</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Prune operations that don't fit in shared memory</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid_stage_count</span><span class="p">(</span><span class="n">target_cc</span><span class="p">,</span> <span class="n">td_from_profiler_td</span><span class="p">(</span><span class="n">td</span><span class="p">))[</span><span class="mi">0</span><span class="p">]:</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="n">new_operation</span> <span class="o">=</span> <span class="n">prof_manifest</span><span class="o">.</span><span class="n">GemmOperation</span><span class="p">(</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">GemmKind</span><span class="o">.</span><span class="n">Universal</span><span class="p">,</span> <span class="n">td</span><span class="o">.</span><span class="n">minimum_compute_capability</span><span class="p">,</span>
|
||||
<span class="n">td</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">type_comb</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">epilogue_functor</span><span class="p">,</span> <span class="n">swizzling_functor</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_kernels</span> <span class="o">=</span> <span class="n">KernelsForDataType</span><span class="p">(</span><span class="n">type_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">)</span>
|
||||
<span class="n">new_kernels</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">new_operation</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">][</span><span class="n">comb</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_kernels</span>
|
||||
|
||||
<span class="c1"># Sort all operations</span>
|
||||
<span class="k">for</span> <span class="n">oc</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">for</span> <span class="n">comb</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">oc</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">oc</span><span class="p">][</span><span class="n">comb</span><span class="p">]</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span>
|
||||
|
||||
<div class="viewcode-block" id="ArchOptions.opclass_supports_combination"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.ArchOptions.opclass_supports_combination">[docs]</a> <span class="k">def</span> <span class="nf">opclass_supports_combination</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">op_class</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="p">,</span> <span class="n">datatype_comb</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">:</span> <span class="nb">tuple</span>
|
||||
<span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns whether the provided operation class supports the provided data type and layout combination</span>
|
||||
|
||||
<span class="sd"> :param op_class: operation class to consider</span>
|
||||
<span class="sd"> :type op_class: cutlass.OpcodeClass</span>
|
||||
<span class="sd"> :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)</span>
|
||||
<span class="sd"> :type datatype_comb: tuple[cutlass.DataType]</span>
|
||||
<span class="sd"> :param layout_comb: tuple of layout types for (layout_A, layout_B)</span>
|
||||
<span class="sd"> :type layout_comb: tuple[cutlass.LayoutType]</span>
|
||||
|
||||
<span class="sd"> :return: whether the provided operation class supports the provided data type and layout combination</span>
|
||||
<span class="sd"> :rtype: bool</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">op_class</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Unexpected or unsupported operation class </span><span class="si">{</span><span class="n">op_class</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">datatype_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">)</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">op_class</span><span class="p">]</span></div>
|
||||
|
||||
<div class="viewcode-block" id="ArchOptions.supporting_opclasses"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.ArchOptions.supporting_opclasses">[docs]</a> <span class="k">def</span> <span class="nf">supporting_opclasses</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span>
|
||||
<span class="n">element_a</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||||
<span class="n">element_b</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||||
<span class="n">element_accumulator</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||||
<span class="n">layout_a</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="p">,</span>
|
||||
<span class="n">layout_b</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="p">,</span>
|
||||
<span class="p">)</span> <span class="o">-></span> <span class="nb">set</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns a set of operation classes that support the provided data type and layout combination</span>
|
||||
|
||||
<span class="sd"> :param element_a: data type of operand A</span>
|
||||
<span class="sd"> :type element_a: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_b: data type of operand B</span>
|
||||
<span class="sd"> :type element_b: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_accumulator: data type of accumulator</span>
|
||||
<span class="sd"> :type element_accumulator: cutlass.DataType</span>
|
||||
<span class="sd"> :param layout_a: layout of operand A</span>
|
||||
<span class="sd"> :type layout_a: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param layout_b: layout of operand B</span>
|
||||
<span class="sd"> :type layout_b: cutlass.LayoutType</span>
|
||||
|
||||
<span class="sd"> :return: set of operation classes that support the provided data type and layout combination</span>
|
||||
<span class="sd"> :rtype: set</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">supporting_op_classes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
|
||||
<span class="n">datatype_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">element_a</span><span class="p">,</span> <span class="n">element_b</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="p">)</span>
|
||||
<span class="n">layout_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">layout_a</span><span class="p">,</span> <span class="n">layout_b</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">op_class</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">opclass_supports_combination</span><span class="p">(</span><span class="n">op_class</span><span class="p">,</span> <span class="n">datatype_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">):</span>
|
||||
<span class="n">supporting_op_classes</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">op_class</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">supporting_op_classes</span></div>
|
||||
|
||||
<div class="viewcode-block" id="ArchOptions.operations"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.ArchOptions.operations">[docs]</a> <span class="k">def</span> <span class="nf">operations</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span>
|
||||
<span class="n">op_class</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="p">,</span>
|
||||
<span class="n">element_a</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||||
<span class="n">element_b</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||||
<span class="n">element_accumulator</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||||
<span class="n">layout_a</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="p">,</span>
|
||||
<span class="n">layout_b</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="p">,</span>
|
||||
<span class="p">)</span> <span class="o">-></span> <span class="n">KernelsForDataType</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the kernels available for the provided operation class and data type and layout combination, raising an exception if the combination is not supported</span>
|
||||
|
||||
<span class="sd"> :param op_class: operation class to consider</span>
|
||||
<span class="sd"> :type op_class: cutlass.OpcodeClass</span>
|
||||
<span class="sd"> :param element_a: data type of operand A</span>
|
||||
<span class="sd"> :type element_a: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_b: data type of operand B</span>
|
||||
<span class="sd"> :type element_b: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_accumulator: data type of accumulator</span>
|
||||
<span class="sd"> :type element_accumulator: cutlass.DataType</span>
|
||||
<span class="sd"> :param layout_a: layout of operand A</span>
|
||||
<span class="sd"> :type layout_a: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param layout_b: layout of operand B</span>
|
||||
<span class="sd"> :type layout_b: cutlass.LayoutType</span>
|
||||
|
||||
<span class="sd"> :return: container of kernels by alignment supported by the provided combination of parameters</span>
|
||||
<span class="sd"> :rtype: KernelsForDataType</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">datatype_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">element_a</span><span class="p">,</span> <span class="n">element_b</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="p">)</span>
|
||||
<span class="n">layout_comb</span> <span class="o">=</span> <span class="p">(</span><span class="n">layout_a</span><span class="p">,</span> <span class="n">layout_b</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">opclass_supports_combination</span><span class="p">(</span><span class="n">op_class</span><span class="p">,</span> <span class="n">datatype_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">):</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s2">"Data type layout combination </span><span class="si">{</span><span class="n">datatype_comb</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="n">layout_comb</span><span class="si">}</span><span class="s2"> "</span>
|
||||
<span class="sa">f</span><span class="s2">"is not supported by opcode class </span><span class="si">{</span><span class="n">op_class</span><span class="si">}</span><span class="s2"> on CC </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">cc</span><span class="si">}</span><span class="s2">."</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">operations_by_opclass</span><span class="p">[</span><span class="n">op_class</span><span class="p">][(</span><span class="n">datatype_comb</span><span class="p">,</span> <span class="n">layout_comb</span><span class="p">)]</span></div></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="OptionRegistry"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.OptionRegistry">[docs]</a><span class="k">class</span> <span class="nc">OptionRegistry</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Container of all architecture-specific options</span>
|
||||
|
||||
<span class="sd"> :param target_cc: compute capability of the device on which operations will be run</span>
|
||||
<span class="sd"> :type target_cc: int</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">target_cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">registry</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
<span class="n">gemm_kinds</span> <span class="o">=</span> <span class="p">[</span><span class="n">cutlass</span><span class="o">.</span><span class="n">GemmKind</span><span class="o">.</span><span class="n">Universal</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">GemmKind</span><span class="o">.</span><span class="n">Universal3x</span><span class="p">]</span>
|
||||
<span class="c1"># Construct options for each CC</span>
|
||||
<span class="k">for</span> <span class="n">kernel_cc</span> <span class="ow">in</span> <span class="n">_generator_ccs</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">registry</span><span class="p">[</span><span class="n">kernel_cc</span><span class="p">]</span> <span class="o">=</span> <span class="n">ArchOptions</span><span class="p">(</span><span class="n">target_cc</span><span class="p">,</span> <span class="n">kernel_cc</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OperationKind</span><span class="o">.</span><span class="n">Gemm</span><span class="p">,</span> <span class="n">gemm_kinds</span><span class="p">)</span>
|
||||
|
||||
<div class="viewcode-block" id="OptionRegistry.options_for_cc"><a class="viewcode-back" href="../../cutlass.html#cutlass.library_defaults.OptionRegistry.options_for_cc">[docs]</a> <span class="k">def</span> <span class="nf">options_for_cc</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">ArchOptions</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">registry</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">cc</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script src="../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../_static/scripts/furo.js"></script>
|
||||
<script src="../../_static/clipboard.min.js"></script>
|
||||
<script src="../../_static/copybutton.js"></script>
|
||||
<script src="../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
980
python/docs/_modules/cutlass/op/gemm.html
Normal file
980
python/docs/_modules/cutlass/op/gemm.html
Normal file
@ -0,0 +1,980 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/op/gemm.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.op.gemm - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.op.gemm</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Ease-of-use interface for constructing, compiling, and running GEMMs.</span>
|
||||
|
||||
<span class="sd"> The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run</span>
|
||||
<span class="sd"> GEMM operations in CUTLASS via Python, without specifying many configuration parameters.</span>
|
||||
<span class="sd"> Under the hood, the interface will select sensible default parameters for the many template</span>
|
||||
<span class="sd"> parameters for CUTLASS GEMMs.</span>
|
||||
|
||||
<span class="sd"> Note: optimal performance is not to be expected from this interface. To achieve optimal</span>
|
||||
<span class="sd"> performance, one should specify and tune each configuration parameter.</span>
|
||||
|
||||
<span class="sd"> The simplest example of using this interface is the following:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> # A, B, C, and D are torch/numpy/cupy tensor objects</span>
|
||||
<span class="sd"> plan = cutlass.op.Gemm(A, B, C, D)</span>
|
||||
<span class="sd"> plan.run()</span>
|
||||
|
||||
|
||||
<span class="sd"> One can also use the interface by specifying data types of operands at construction</span>
|
||||
<span class="sd"> and using different tensor objects with these data types at runtime:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> # The following is shorthand for:</span>
|
||||
<span class="sd"> # cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,</span>
|
||||
<span class="sd"> # element_C=torch.float32, element_D=torch.float32,</span>
|
||||
<span class="sd"> # element_accumulator=torch.float32,</span>
|
||||
<span class="sd"> # layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
|
||||
<span class="sd"> A0 = torch.rand((128, 256), device='cuda')</span>
|
||||
<span class="sd"> B0 = torch.rand((256, 64), device='cuda')</span>
|
||||
<span class="sd"> C0 = torch.zeros((128, 64), device='cuda')</span>
|
||||
<span class="sd"> D0 = torch.zeros((128, 64), device='cuda')</span>
|
||||
<span class="sd"> plan.run(A0, B0, C0, D0)</span>
|
||||
|
||||
<span class="sd"> A1 = torch.rand((32, 128), device='cuda')</span>
|
||||
<span class="sd"> B1 = torch.rand((128, 256), device='cuda')</span>
|
||||
<span class="sd"> C1 = torch.zeros((32, 256), device='cuda')</span>
|
||||
<span class="sd"> D1 = torch.zeros((32, 256), device='cuda')</span>
|
||||
<span class="sd"> plan.run(A1, B1, C1, D1)</span>
|
||||
|
||||
<span class="sd"> The interface additionally enables one to decouple the compilation of the underlying CUTLASS</span>
|
||||
<span class="sd"> kernel from its execution:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> plan.compile()</span>
|
||||
|
||||
<span class="sd"> # Do other work...</span>
|
||||
|
||||
<span class="sd"> plan.run(A0, B0, C0, D0)</span>
|
||||
|
||||
<span class="sd"> # Do other work...</span>
|
||||
|
||||
<span class="sd"> plan.run(A1, B1, C1, D1)</span>
|
||||
|
||||
<span class="sd"> Elementwise activation functions are easily fused to the GEMM via the interface:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> plan.activation = cutlass.epilogue.relu</span>
|
||||
|
||||
<span class="sd"> Operations can also be run asynchronously:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> args = plan.run()</span>
|
||||
|
||||
<span class="sd"> # Do other work...</span>
|
||||
|
||||
<span class="sd"> args.sync()</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass</span> <span class="kn">import</span> <span class="n">epilogue</span><span class="p">,</span> <span class="n">swizzle</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend</span> <span class="kn">import</span> <span class="n">compiler</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.gemm_operation</span> <span class="kn">import</span> <span class="n">GemmArguments</span><span class="p">,</span> <span class="n">GemmOperationUniversal</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="n">TensorDescription</span><span class="p">,</span> <span class="n">TileDescription</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.op.op</span> <span class="kn">import</span> <span class="n">OperationBase</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.utils</span> <span class="kn">import</span> <span class="n">check</span><span class="p">,</span> <span class="n">datatypes</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="Gemm"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm.Gemm">[docs]</a><span class="k">class</span> <span class="nc">Gemm</span><span class="p">(</span><span class="n">OperationBase</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Constructs a ``Gemm`` object.</span>
|
||||
|
||||
<span class="sd"> The data types and layouts of operands A, B, and C, along with the data type of output D</span>
|
||||
<span class="sd"> and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --</span>
|
||||
<span class="sd"> these are not to be changed after a ``Gemm`` has been constructed.</span>
|
||||
|
||||
<span class="sd"> The constructor has optional parameters for flexibly setting these parameters. The following</span>
|
||||
<span class="sd"> constructors are equivalent:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> # Use F32 for A, B, C, D, and accumulation. All operands are row major.</span>
|
||||
|
||||
<span class="sd"> # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts</span>
|
||||
<span class="sd"> # for operands to the same values.</span>
|
||||
<span class="sd"> Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
|
||||
<span class="sd"> # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.</span>
|
||||
<span class="sd"> Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,</span>
|
||||
<span class="sd"> element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
|
||||
<span class="sd">        # Set the data types and layouts from existing tensors. Note that one can use different tensors when</span>
|
||||
<span class="sd"> # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must</span>
|
||||
<span class="sd"> # have the same data type and layout as those passed in here).</span>
|
||||
<span class="sd"> # A, B, C, and D are row-major torch.Tensor objects of type torch.float32</span>
|
||||
<span class="sd"> Gemm(A=A, B=B, C=C, D=D)</span>
|
||||
|
||||
<span class="sd"> # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is</span>
|
||||
<span class="sd">        # the same as that for C, at present)</span>
|
||||
<span class="sd"> Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,</span>
|
||||
<span class="sd"> layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)</span>
|
||||
|
||||
<span class="sd"> # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types</span>
|
||||
<span class="sd"> # and layouts will inherit those passed in via the generic ``element`` and ``layout``</span>
|
||||
<span class="sd"> Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,</span>
|
||||
<span class="sd"> element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
|
||||
<span class="sd"> The order of precedence for the setting of the data type and layout for a given operand/output is as follows:</span>
|
||||
<span class="sd"> 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor</span>
|
||||
<span class="sd"> 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those</span>
|
||||
<span class="sd"> 3) Otherwise, use the generic values (e.g., ``element``, ``layout``)</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80</span>
|
||||
<span class="sd"> :type kernel_cc: int</span>
|
||||
<span class="sd"> :param A: tensor representing data type and layout of operand A</span>
|
||||
<span class="sd"> :param B: tensor representing data type and layout of operand B</span>
|
||||
<span class="sd"> :param C: tensor representing data type and layout of operand C</span>
|
||||
<span class="sd"> :param D: tensor representing data type and layout of operand D</span>
|
||||
<span class="sd">    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B</span>
|
||||
<span class="sd"> :param beta: scalar parameter beta from GEMM operation that scales operand C</span>
|
||||
<span class="sd"> :param element_accumulator: data type to be used in accumulation of the product of operands A and B</span>
|
||||
<span class="sd"> :type element_accumulator: cutlass.DataType</span>
|
||||
<span class="sd"> :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type</span>
|
||||
<span class="sd"> :type element: cutlass.DataType</span>
|
||||
<span class="sd"> :param layout: generic layout type to be used for operands A, B, C, and D</span>
|
||||
<span class="sd"> :type layout: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param element_A: data type to be used for operand A</span>
|
||||
<span class="sd"> :type element_A: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_B: data type to be used for operand B</span>
|
||||
<span class="sd"> :type element_B: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_C: data type to be used for operand C</span>
|
||||
<span class="sd"> :type element_C: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_D: data type to be used for operand D</span>
|
||||
<span class="sd"> :type element_D: cutlass.DataType</span>
|
||||
<span class="sd">    :param layout_A: layout of operand A</span>
|
||||
<span class="sd">    :type layout_A: cutlass.LayoutType</span>
|
||||
<span class="sd">    :param layout_B: layout of operand B</span>
|
||||
<span class="sd">    :type layout_B: cutlass.LayoutType</span>
|
||||
<span class="sd">    :param layout_C: layout of operand C</span>
|
||||
<span class="sd">    :type layout_C: cutlass.LayoutType</span>
|
||||
<span class="sd">    :param layout_D: layout of operand D</span>
|
||||
<span class="sd">    :type layout_D: cutlass.LayoutType</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alpha</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">element</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">element_A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">layout_A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout_B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout_C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">cc</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">kernel_cc</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">cc</span><span class="o">=</span><span class="n">cc</span><span class="p">,</span> <span class="n">kernel_cc</span><span class="o">=</span><span class="n">kernel_cc</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="s2">"gemm"</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">compiled</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="n">elements</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">layouts</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
|
||||
<span class="c1"># Check that at least one of the following is set for each tensor (illustrated assuming tensor A):</span>
|
||||
<span class="c1"># ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``</span>
|
||||
<span class="k">for</span> <span class="n">elt</span><span class="p">,</span> <span class="n">lay</span><span class="p">,</span> <span class="n">tens</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">([</span><span class="n">element_A</span><span class="p">,</span> <span class="n">element_B</span><span class="p">,</span> <span class="n">element_C</span><span class="p">,</span> <span class="n">element_D</span><span class="p">],</span>
|
||||
<span class="p">[</span><span class="n">layout_A</span><span class="p">,</span> <span class="n">layout_B</span><span class="p">,</span> <span class="n">layout_C</span><span class="p">,</span> <span class="n">layout_C</span><span class="p">],</span>
|
||||
<span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">D</span><span class="p">],</span>
|
||||
<span class="p">[</span><span class="s2">"A"</span><span class="p">,</span> <span class="s2">"B"</span><span class="p">,</span> <span class="s2">"C"</span><span class="p">,</span> <span class="s2">"D"</span><span class="p">]):</span>
|
||||
<span class="k">if</span> <span class="n">elt</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">tens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Must not specify both element_</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1"> and tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">lay</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">tens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Must not specify both layout_</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1"> and tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">elt</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">tens</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">element</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Must specify one of element_</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">, tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">, or generic element.'</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">lay</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">tens</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">layout</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Must specify one of layout_</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">, tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">, or generic layout.'</span><span class="p">)</span>
|
||||
|
||||
<span class="n">elt_to_set</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">lay_to_set</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="k">if</span> <span class="n">tens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">elt_to_set</span><span class="p">,</span> <span class="n">lay_to_set</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">get_datatype_and_layout</span><span class="p">(</span><span class="n">tens</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">elt_to_set</span> <span class="o">=</span> <span class="n">elt</span> <span class="k">if</span> <span class="n">elt</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">element</span>
|
||||
<span class="n">lay_to_set</span> <span class="o">=</span> <span class="n">lay</span> <span class="k">if</span> <span class="n">lay</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">layout</span>
|
||||
|
||||
<span class="n">elements</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">datatypes</span><span class="o">.</span><span class="n">library_type</span><span class="p">(</span><span class="n">elt_to_set</span><span class="p">))</span>
|
||||
<span class="n">layouts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">datatypes</span><span class="o">.</span><span class="n">library_layout</span><span class="p">(</span><span class="n">lay_to_set</span><span class="p">))</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_d</span> <span class="o">=</span> <span class="n">elements</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_d</span> <span class="o">=</span> <span class="n">layouts</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">element_accumulator</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">library_type</span><span class="p">(</span><span class="n">element_accumulator</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">A</span> <span class="o">=</span> <span class="n">A</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">B</span> <span class="o">=</span> <span class="n">B</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">C</span> <span class="o">=</span> <span class="n">C</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">D</span> <span class="o">=</span> <span class="n">D</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_operations</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_swizzling_functor</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">swizzle</span><span class="o">.</span><span class="n">IdentitySwizzle1</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_reset_operations</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_epilogue</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
|
||||
<span class="c1"># Set the default op class</span>
|
||||
<span class="n">datatype_comb</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">)</span>
|
||||
<span class="n">layout_comb</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">possible_op_classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">options</span><span class="o">.</span><span class="n">supporting_opclasses</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">TensorOp</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_op_classes</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">opclass</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">TensorOp</span>
|
||||
<span class="k">elif</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_op_classes</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">opclass</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'No kernel configuration found for supported data type and layout '</span>
|
||||
<span class="sa">f</span><span class="s1">'combination </span><span class="si">{</span><span class="n">datatype_comb</span><span class="si">}</span><span class="s1">x</span><span class="si">{</span><span class="n">layout_comb</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">reset_epilogue</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_epilogue_functor_activation</span><span class="p">(</span><span class="n">epilogue</span><span class="o">.</span><span class="n">identity</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_reset_epilogue_functor_activation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">activation</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">:</span>
|
||||
<span class="n">elements_per_access</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">elements_per_access</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">//</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">]</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">elements_per_access</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="o">.</span><span class="n">epilogue_vector_length</span>
|
||||
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">specified_kernel_cc</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">==</span> <span class="mi">90</span> <span class="ow">and</span> <span class="n">activation</span> <span class="o">!=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">identity</span><span class="p">:</span>
|
||||
<span class="c1"># CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,</span>
|
||||
<span class="c1"># revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Reverting to using SM80-tagged kernel. Opclass may change."</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_options</span><span class="p">(</span><span class="mi">80</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_operations</span><span class="p">(</span><span class="n">reset_epilogue</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cc</span> <span class="o">==</span> <span class="mi">90</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">!=</span> <span class="mi">90</span> <span class="ow">and</span> <span class="n">activation</span> <span class="o">==</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">identity</span><span class="p">):</span>
|
||||
<span class="c1"># SM80 fallback kernels are currently used. Since an identity activation is requested,</span>
|
||||
<span class="c1"># we can switch back to using SM90 kernels.</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_options</span><span class="p">(</span><span class="mi">90</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_operations</span><span class="p">(</span><span class="n">reset_epilogue</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">==</span> <span class="mi">90</span> <span class="ow">and</span> <span class="n">activation</span> <span class="o">!=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">identity</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"Epilogues with elementwise fusion are not currently supported "</span>
|
||||
<span class="s2">"in the Python interface for 3.x kernels. To use 2.x kernels "</span>
|
||||
<span class="s2">"with fused elementwise epilogues, do not set the `kernel_cc` "</span>
|
||||
<span class="s2">"parameter when constructing the Gemm object."</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">get_activation_epilogue</span><span class="p">(</span>
|
||||
<span class="n">activation</span><span class="p">,</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">),</span>
|
||||
<span class="n">elements_per_access</span><span class="p">,</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_reset_epilogue_functor_alignment</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alignment</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="p">,</span> <span class="s1">'activation_functor'</span><span class="p">):</span>
|
||||
<span class="n">activation</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">identity</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">activation</span> <span class="o">=</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="o">.</span><span class="n">activation_functor</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span> <span class="o">=</span> <span class="n">epilogue</span><span class="o">.</span><span class="n">get_activation_epilogue</span><span class="p">(</span>
|
||||
<span class="n">activation</span><span class="p">,</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">),</span>
|
||||
<span class="n">alignment</span><span class="p">,</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the type of the current activation function used</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="o">.</span><span class="n">activation_functor</span><span class="p">)</span>
|
||||
|
||||
<span class="nd">@activation</span><span class="o">.</span><span class="n">setter</span>
|
||||
<span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">act</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Sets the type of the activation function to use</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_epilogue_functor_activation</span><span class="p">(</span><span class="n">act</span><span class="p">)</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">opclass</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the opcode class currently in use by the GEMM</span>
|
||||
|
||||
<span class="sd"> :return: opcode class currently in use</span>
|
||||
<span class="sd"> :rtype: cutlass.OpcodeClass</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">op_class</span>
|
||||
|
||||
<span class="nd">@opclass</span><span class="o">.</span><span class="n">setter</span>
|
||||
<span class="k">def</span> <span class="nf">opclass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">oc</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Sets the opcode class to use in the GEMM. If the opcode class is not supported under</span>
|
||||
<span class="sd"> the given compute capability and element/layout combinations of the GEMM, an exception is raised.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">oc</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_op_classes</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">=</span> <span class="n">oc</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s1">'Unsupported operation class </span><span class="si">{</span><span class="n">oc</span><span class="si">}</span><span class="s1"> for CC </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">cc</span><span class="si">}</span><span class="s1"> and data type combination '</span>
|
||||
<span class="sa">f</span><span class="s1">'(</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="si">}</span><span class="s1">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="si">}</span><span class="s1">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="si">}</span><span class="s1">) and '</span>
|
||||
<span class="sa">f</span><span class="s1">'layout combination (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="si">}</span><span class="s1">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="si">}</span><span class="s1">).'</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Changing the op class changes the elements per access in the epilogue. Reset this.</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">:</span>
|
||||
<span class="n">elements_per_access</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">elements_per_access</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">//</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">]</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_epilogue_functor_alignment</span><span class="p">(</span><span class="n">elements_per_access</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Changing the op class also changes the possible operations available. Reset these.</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">options</span><span class="o">.</span><span class="n">operations</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">op_class</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_element_accumulator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">)</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">swizzling_functor</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the type of the swizzling functor currently being used by the GEMM</span>
|
||||
|
||||
<span class="sd"> :return: swizzling functor type</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_swizzling_functor</span>
|
||||
|
||||
<span class="nd">@swizzling_functor</span><span class="o">.</span><span class="n">setter</span>
|
||||
<span class="k">def</span> <span class="nf">swizzling_functor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">swizzling_functor</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Sets the swizzling functor to the type specified by `swizzling_functor`</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">swizzling_functor</span> <span class="o">==</span> <span class="n">swizzle</span><span class="o">.</span><span class="n">ThreadblockSwizzleStreamK</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">op_class</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">'ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">==</span> <span class="mi">90</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">'ThreadblockSwizzleStreamK is currently unsupported on SM90'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_swizzling_functor</span> <span class="o">=</span> <span class="n">swizzling_functor</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_valid_tile_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">td</span><span class="p">:</span> <span class="n">TileDescription</span><span class="p">)</span> <span class="o">-></span> <span class="nb">tuple</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Checks whether the provided tile description is valid for the given compute capability. At present,</span>
|
||||
<span class="sd"> this checks the following:</span>
|
||||
|
||||
<span class="sd"> - Does the tile description use a number of stages supported by the compute capability in question?</span>
|
||||
<span class="sd"> - Does the tile size requested fit within shared memory?</span>
|
||||
<span class="sd"> - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,</span>
|
||||
<span class="sd"> more non-unit cluster dimensions for pre-SM90 architectures)?</span>
|
||||
<span class="sd"> - Is the kernel schedule being used supported on the architecture in question?</span>
|
||||
|
||||
<span class="sd"> :param td: tile description to validate</span>
|
||||
<span class="sd"> :type td: cutlass.backend.TileDescription</span>
|
||||
<span class="sd"> :return: tuple in which the first element is a bool indicating that the tile description is valid</span>
|
||||
<span class="sd"> and the second element is a string providing an optional error message.</span>
|
||||
<span class="sd"> :rtype: tuple</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="c1"># Check stage count based on the CC to which we are compiling (self.cc), rather</span>
|
||||
<span class="c1"># than the CC from which we find kernels (self.current_cc)</span>
|
||||
<span class="n">valid</span><span class="p">,</span> <span class="n">msg</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">valid_stage_count</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cc</span><span class="p">,</span> <span class="n">td</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">valid</span><span class="p">,</span> <span class="n">msg</span><span class="p">)</span>
|
||||
|
||||
<span class="n">valid</span><span class="p">,</span> <span class="n">msg</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">valid_cluster_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">,</span> <span class="n">td</span><span class="o">.</span><span class="n">cluster_shape</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">valid</span><span class="p">,</span> <span class="n">msg</span><span class="p">)</span>
|
||||
|
||||
<span class="n">valid</span><span class="p">,</span> <span class="n">msg</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">valid_kernel_schedule</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">,</span> <span class="n">td</span><span class="o">.</span><span class="n">kernel_schedule</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">valid</span><span class="p">,</span> <span class="n">msg</span>
|
||||
|
||||
<div class="viewcode-block" id="Gemm.tile_descriptions"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm.Gemm.tile_descriptions">[docs]</a> <span class="k">def</span> <span class="nf">tile_descriptions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns a list of valid tile descriptions for the operations</span>
|
||||
|
||||
<span class="sd"> :returns: list of valid tile descriptions for the operations</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="p">[</span><span class="n">datatypes</span><span class="o">.</span><span class="n">td_from_profiler_op</span><span class="p">(</span><span class="n">op</span><span class="p">)</span> <span class="k">for</span> <span class="n">op</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">all_operations</span><span class="p">]</span></div>
|
||||
|
||||
<div class="viewcode-block" id="Gemm.construct"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm.Gemm.construct">[docs]</a> <span class="k">def</span> <span class="nf">construct</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">tile_description</span><span class="p">:</span> <span class="n">TileDescription</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alignment_A</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">alignment_B</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">alignment_C</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">GemmOperationUniversal</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Constructs a ``cutlass.backend.GemmOperationUniversal`` based on the input parameters and current</span>
|
||||
<span class="sd"> kernel specification of the ``Gemm`` object.</span>
|
||||
|
||||
<span class="sd"> :param tile_description: tile description specifying shapes and operand types to use in the kernel</span>
|
||||
<span class="sd"> :type tile_description: cutlass.backend.TileDescription</span>
|
||||
<span class="sd"> :param alignment_A: alignment of operand A</span>
|
||||
<span class="sd"> :type alignment_A: int</span>
|
||||
<span class="sd"> :param alignment_B: alignment of operand B</span>
|
||||
<span class="sd"> :type alignment_B: int</span>
|
||||
<span class="sd"> :param alignment_C: alignment of operand C</span>
|
||||
<span class="sd"> :type alignment_C: int</span>
|
||||
|
||||
<span class="sd"> :return: operation that was constructed</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.GemmOperationUniversal</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">alignment_pref_A</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mi">128</span> <span class="o">//</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">],</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">alignments</span><span class="p">))</span>
|
||||
<span class="n">alignment_pref_B</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mi">128</span> <span class="o">//</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">],</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">alignments</span><span class="p">))</span>
|
||||
<span class="n">alignment_pref_C</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mi">128</span> <span class="o">//</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">],</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">alignments</span><span class="p">))</span>
|
||||
<span class="n">alignment_A</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_A</span><span class="p">,</span> <span class="n">alignment_pref_A</span><span class="p">)</span>
|
||||
<span class="n">alignment_B</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_B</span><span class="p">,</span> <span class="n">alignment_pref_B</span><span class="p">)</span>
|
||||
<span class="n">alignment_C</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_C</span><span class="p">,</span> <span class="n">alignment_pref_C</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_epilogue_functor_alignment</span><span class="p">(</span><span class="n">alignment_C</span><span class="p">)</span>
|
||||
|
||||
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">),</span>
|
||||
<span class="n">alignment_A</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">),</span>
|
||||
<span class="n">alignment_B</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">),</span>
|
||||
<span class="n">alignment_C</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">tile_description</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">op</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">operations</span><span class="p">(</span><span class="n">alignment_A</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">tile_description</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">td_from_profiler_op</span><span class="p">(</span><span class="n">op</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">valid</span><span class="p">,</span> <span class="n">err_str</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_valid_tile_description</span><span class="p">(</span><span class="n">tile_description</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid tile description. </span><span class="si">{</span><span class="n">err_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span> <span class="o">=</span> <span class="n">tile_description</span>
|
||||
|
||||
<span class="n">operation</span> <span class="o">=</span> <span class="n">GemmOperationUniversal</span><span class="p">(</span>
|
||||
<span class="n">arch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">,</span>
|
||||
<span class="n">tile_description</span><span class="o">=</span><span class="n">tile_description</span><span class="p">,</span>
|
||||
<span class="n">A</span><span class="o">=</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">tensor_B</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">tensor_C</span><span class="p">,</span>
|
||||
<span class="n">epilogue_functor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="p">,</span>
|
||||
<span class="n">swizzling_functor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_swizzling_functor</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">operation</span></div>
|
||||
|
||||
<div class="viewcode-block" id="Gemm.compile"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm.Gemm.compile">[docs]</a> <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tile_description</span><span class="p">:</span> <span class="n">TileDescription</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alignment_A</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">alignment_B</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">alignment_C</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">print_module</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">cutlass</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">GemmOperationUniversal</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Emits and compiles the kernel currently specified. If ``tile_description`` and any</span>
|
||||
<span class="sd"> of the ``alignment`` parameters are set, the kernel will be chosen using this</span>
|
||||
<span class="sd"> tile description and alignments. Otherwise, a default tile description and alignment</span>
|
||||
<span class="sd"> will be used.</span>
|
||||
|
||||
<span class="sd"> :param tile_description: tile description specifying shapes and operand types to use in the kernel</span>
|
||||
<span class="sd"> :type tile_description: cutlass.backend.TileDescription</span>
|
||||
<span class="sd"> :param alignment_A: alignment of operand A</span>
|
||||
<span class="sd"> :type alignment_A: int</span>
|
||||
<span class="sd"> :param alignment_B: alignment of operand B</span>
|
||||
<span class="sd"> :type alignment_B: int</span>
|
||||
<span class="sd"> :param alignment_C: alignment of operand C</span>
|
||||
<span class="sd"> :type alignment_C: int</span>
|
||||
<span class="sd"> :param print_module: whether to print the emitted C++ code</span>
|
||||
<span class="sd"> :type print_module: bool</span>
|
||||
|
||||
<span class="sd"> :return: operation that was compiled</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.GemmOperationUniversal</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">construct</span><span class="p">(</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">alignment_A</span><span class="p">,</span> <span class="n">alignment_B</span><span class="p">,</span> <span class="n">alignment_C</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">print_module</span><span class="p">:</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">rt_module</span><span class="o">.</span><span class="n">emit</span><span class="p">())</span>
|
||||
|
||||
<span class="n">compiler</span><span class="o">.</span><span class="n">add_module</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="p">,])</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">operation</span></div>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_verify_type_and_layout</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tensor</span><span class="p">,</span> <span class="n">ref_type</span><span class="p">,</span> <span class="n">ref_layout</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception</span>
|
||||
<span class="sd"> is raised if it does not.</span>
|
||||
|
||||
<span class="sd"> :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in</span>
|
||||
<span class="sd"> :type tensor: numpy/cupy/torch array/tensor object</span>
|
||||
<span class="sd">        :param ref_type: data type for the tensor that this object was initialized to</span>
|
||||
<span class="sd"> :param ref_layout: layout for the tensor that this object was initialized to</span>
|
||||
<span class="sd"> :param name: identifier of the tensor to verify. Used in raising exceptions</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">dtype</span><span class="p">,</span> <span class="n">layout</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">get_datatype_and_layout</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">dtype</span> <span class="o">!=</span> <span class="n">ref_type</span> <span class="ow">or</span> <span class="n">layout</span> <span class="o">!=</span> <span class="n">ref_layout</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1"> with type and layout (</span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s1">, </span><span class="si">{</span><span class="n">layout</span><span class="si">}</span><span class="s1">) '</span>
|
||||
<span class="sa">f</span><span class="s1">'does not match the expected type and '</span>
|
||||
<span class="sa">f</span><span class="s1">'layout of (</span><span class="si">{</span><span class="n">ref_type</span><span class="si">}</span><span class="s1">, </span><span class="si">{</span><span class="n">ref_layout</span><span class="si">}</span><span class="s1">).'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_verify_tensor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tensor</span><span class="p">,</span> <span class="n">ref_tensor</span><span class="p">,</span> <span class="n">ref_dtype</span><span class="p">,</span> <span class="n">ref_layout</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Verifies the following properties:</span>
|
||||
<span class="sd"> 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)</span>
|
||||
<span class="sd">        2) If ``tensor`` is not ``None``, its datatype and layout must match the current versions</span>
|
||||
<span class="sd"> set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)</span>
|
||||
|
||||
<span class="sd"> If either of these properties does not hold, an exception is raised. If these properties hold and</span>
|
||||
<span class="sd"> ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.</span>
|
||||
|
||||
<span class="sd"> :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in</span>
|
||||
<span class="sd"> :type tensor: numpy/cupy/torch array/tensor object</span>
|
||||
<span class="sd"> :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in</span>
|
||||
<span class="sd"> :type ref_tensor: numpy/cupy/torch array/tensor object</span>
|
||||
<span class="sd"> :param ref_dtype: data type for the tensor that this object was initialized to</span>
|
||||
<span class="sd"> :param ref_layout: layout for the tensor that this object was initialized to</span>
|
||||
<span class="sd"> :param name: identifier of the tensor to verify. Used in raising exceptions</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
|
||||
<span class="sd"> :return: valid tensor object to use</span>
|
||||
<span class="sd"> :rtype: numpy/cupy/torch array/tensor object</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">ref_tensor</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> must be set."</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">ref_tensor</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_verify_type_and_layout</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">ref_dtype</span><span class="p">,</span> <span class="n">ref_layout</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">tensor</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_verify_scalar</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">,</span> <span class="n">ref_scalar</span><span class="p">,</span> <span class="n">ref_dtype</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Verifies the following properties:</span>
|
||||
<span class="sd">        1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)</span>
|
||||
<span class="sd">        2) If ``scalar`` is not ``None``, its datatype must match the current version</span>
|
||||
<span class="sd"> set by the plan (i.e., those in ``ref_dtype``)</span>
|
||||
|
||||
<span class="sd"> If either of these properties does not hold, an exception is raised. If these properties hold and</span>
|
||||
<span class="sd"> ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.</span>
|
||||
|
||||
<span class="sd">        :param scalar: object representing a scalar passed in to verify, or ``None`` if no scalar was passed in</span>
|
||||
<span class="sd"> :type scalar: numpy/cupy/torch scalar</span>
|
||||
<span class="sd">        :param ref_scalar: object representing a scalar passed in on construction of this object, or ``None`` if no scalar was passed in</span>
|
||||
<span class="sd"> :type ref_scalar: numpy/cupy/torch scalar</span>
|
||||
<span class="sd"> :param ref_dtype: data type for the scalar that this object was initialized to</span>
|
||||
<span class="sd"> :param name: identifier of the scalar to verify. Used in raising exceptions</span>
|
||||
<span class="sd"> :type name: str</span>
|
||||
|
||||
<span class="sd"> :return: valid scalar to use</span>
|
||||
<span class="sd"> :rtype: numpy/cupy/torch scalar</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">scalar</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">ref_scalar</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Scalar </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> must be set."</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">ref_scalar</span>
|
||||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">library_type</span><span class="p">(</span><span class="n">scalar</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">dtype</span> <span class="o">!=</span> <span class="n">ref_dtype</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s2">"Tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with type </span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s2"> does not match expected type </span><span class="si">{</span><span class="n">ref_dtype</span><span class="si">}</span><span class="s2">."</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">scalar</span>
|
||||
|
||||
<div class="viewcode-block" id="Gemm.run"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm.Gemm.run">[docs]</a> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alpha</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">batch_count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||||
<span class="n">sync</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">print_module</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">GemmArguments</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Runs the kernel currently specified. If it has not already been, the kernel is emitted and</span>
|
||||
<span class="sd"> compiled. Tensors holding operands and outputs of the kernel are sourced either from the</span>
|
||||
<span class="sd"> ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``</span>
|
||||
<span class="sd"> parameters provided in this call, or from those</span>
|
||||
<span class="sd"> passed in on the construction of this object -- one of the two must be specified.</span>
|
||||
|
||||
<span class="sd"> By default, this call returns only once the kernel has completed. To launch the kernel</span>
|
||||
<span class="sd"> and immediately return, set ``sync=False``. In this case, it is the responsibility of the</span>
|
||||
<span class="sd">        caller to synchronize the results of the kernel before attempting to access outputs</span>
|
||||
<span class="sd"> by calling ``sync()`` on the arguments returned from this call.</span>
|
||||
|
||||
<span class="sd"> :param A: tensor representing data type and layout of operand A</span>
|
||||
<span class="sd"> :param B: tensor representing data type and layout of operand B</span>
|
||||
<span class="sd"> :param C: tensor representing data type and layout of operand C</span>
|
||||
<span class="sd"> :param D: tensor representing data type and layout of operand D</span>
|
||||
<span class="sd">        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B</span>
|
||||
<span class="sd"> :param beta: scalar parameter beta from GEMM operation that scales operand C</span>
|
||||
<span class="sd"> :param batch_count: number of GEMMs in the batch</span>
|
||||
<span class="sd"> :type batch_count: int</span>
|
||||
<span class="sd"> :param sync: whether the call should wait for the kernel to complete before returning</span>
|
||||
<span class="sd"> :type sync: bool</span>
|
||||
<span class="sd"> :param print_module: whether to print the emitted C++ code</span>
|
||||
<span class="sd"> :type print_module: bool</span>
|
||||
|
||||
<span class="sd"> :return: arguments passed in to the kernel</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.GemmArguments</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">batch_count</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid batch count </span><span class="si">{</span><span class="n">batch_count</span><span class="si">}</span><span class="s2">. Value must be an integer >= 1."</span><span class="p">)</span>
|
||||
|
||||
<span class="n">A</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">A</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="s2">"A"</span><span class="p">)</span>
|
||||
<span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">B</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">,</span> <span class="s2">"B"</span><span class="p">)</span>
|
||||
<span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">C</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">,</span> <span class="s2">"C"</span><span class="p">)</span>
|
||||
<span class="n">D</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">D</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_d</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_d</span><span class="p">,</span> <span class="s2">"D"</span><span class="p">)</span>
|
||||
<span class="n">alpha</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_scalar</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="s2">"alpha"</span><span class="p">)</span>
|
||||
<span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_scalar</span><span class="p">(</span><span class="n">beta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="s2">"beta"</span><span class="p">)</span>
|
||||
|
||||
<span class="n">alignment_a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">)</span>
|
||||
<span class="n">alignment_b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">)</span>
|
||||
<span class="n">alignment_c</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">alignment_A</span><span class="o">=</span><span class="n">alignment_a</span><span class="p">,</span> <span class="n">alignment_B</span><span class="o">=</span><span class="n">alignment_b</span><span class="p">,</span>
|
||||
<span class="n">alignment_C</span><span class="o">=</span><span class="n">alignment_c</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
|
||||
|
||||
<span class="n">problem_size</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">gemm</span><span class="o">.</span><span class="n">GemmCoord</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">batch_count</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="n">mode</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">gemm</span><span class="o">.</span><span class="n">Mode</span><span class="o">.</span><span class="n">Gemm</span>
|
||||
<span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'split_k_slices'</span><span class="p">:</span> <span class="mi">1</span><span class="p">}</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">mode</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">gemm</span><span class="o">.</span><span class="n">Mode</span><span class="o">.</span><span class="n">Batched</span>
|
||||
<span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'batch'</span><span class="p">:</span> <span class="n">batch_count</span><span class="p">}</span>
|
||||
|
||||
<span class="n">arguments</span> <span class="o">=</span> <span class="n">GemmArguments</span><span class="p">(</span>
|
||||
<span class="n">operation</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="p">,</span> <span class="n">problem_size</span><span class="o">=</span><span class="n">problem_size</span><span class="p">,</span>
|
||||
<span class="n">A</span><span class="o">=</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">C</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="n">D</span><span class="p">,</span>
|
||||
<span class="n">output_op</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">epilogue_type</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">),</span>
|
||||
<span class="n">gemm_mode</span><span class="o">=</span><span class="n">mode</span><span class="p">,</span>
|
||||
<span class="o">**</span><span class="n">kwargs</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">arguments</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">sync</span><span class="p">:</span>
|
||||
<span class="n">arguments</span><span class="o">.</span><span class="n">sync</span><span class="p">()</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">arguments</span></div></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
||||
<script src="../../../_static/doctools.js"></script>
|
||||
<script src="../../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../../_static/scripts/furo.js"></script>
|
||||
<script src="../../../_static/clipboard.min.js"></script>
|
||||
<script src="../../../_static/copybutton.js"></script>
|
||||
<script src="../../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
554
python/docs/_modules/cutlass/op/gemm_grouped.html
Normal file
554
python/docs/_modules/cutlass/op/gemm_grouped.html
Normal file
@ -0,0 +1,554 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/op/gemm_grouped.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.op.gemm_grouped - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.op.gemm_grouped</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Ease-of-use interface for constructing, compiling, and running GEMMs.</span>
|
||||
|
||||
<span class="sd"> The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run</span>
|
||||
<span class="sd"> grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.</span>
|
||||
<span class="sd"> Under the hood, the interface will select sensible default parameters for the many template</span>
|
||||
<span class="sd"> parameters for CUTLASS grouped GEMMs.</span>
|
||||
|
||||
<span class="sd"> Note: optimal performance is not to be expected from this interface. To achieve optimal</span>
|
||||
<span class="sd"> performance, one should specify and tune each configuration parameter.</span>
|
||||
|
||||
<span class="sd"> The simplest example of using this interface is the following:</span>
|
||||
|
||||
<span class="sd"> .. highlight:: python</span>
|
||||
<span class="sd"> .. code-block:: python</span>
|
||||
|
||||
<span class="sd"> # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects</span>
|
||||
<span class="sd"> plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)</span>
|
||||
<span class="sd"> plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.gemm_operation</span> <span class="kn">import</span> <span class="p">(</span>
|
||||
<span class="n">GemmGroupedArguments</span><span class="p">,</span>
|
||||
<span class="n">GemmOperationGrouped</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="p">(</span>
|
||||
<span class="n">DataTypeSize</span><span class="p">,</span>
|
||||
<span class="n">SchedulerMode</span><span class="p">,</span>
|
||||
<span class="n">TensorDescription</span><span class="p">,</span>
|
||||
<span class="n">TileDescription</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.op.gemm</span> <span class="kn">import</span> <span class="n">Gemm</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.utils</span> <span class="kn">import</span> <span class="n">check</span><span class="p">,</span> <span class="n">datatypes</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="GroupedGemm"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm_grouped.GroupedGemm">[docs]</a><span class="k">class</span> <span class="nc">GroupedGemm</span><span class="p">(</span><span class="n">Gemm</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Constructs a ``GroupedGemm`` object.</span>
|
||||
|
||||
<span class="sd"> The data types and layouts of operands A, B, and C, along with the data type of output D</span>
|
||||
<span class="sd"> and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --</span>
|
||||
<span class="sd"> these are not to be changed after a ``GroupedGemm`` has been constructed.</span>
|
||||
|
||||
<span class="sd"> The constructor has optional parameters for flexibly setting these parameters. Please see the constructor</span>
|
||||
<span class="sd"> for ``Gemm`` for examples of these.</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability of device to generate kernels for</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param A: tensor representing data type and layout of operands A</span>
|
||||
<span class="sd"> :param B: tensor representing data type and layout of operands B</span>
|
||||
<span class="sd"> :param C: tensor representing data type and layout of operands C</span>
|
||||
<span class="sd"> :param D: tensor representing data type and layout of operands D</span>
|
||||
<span class="sd"> :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B</span>
|
||||
<span class="sd"> :param beta: scalar parameter beta from GEMM operation that scales operand C</span>
|
||||
<span class="sd"> :param element_accumulator: data type to be used in accumulation of the product of operands A and B</span>
|
||||
<span class="sd"> :type element_accumulator: cutlass.DataType</span>
|
||||
<span class="sd"> :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type</span>
|
||||
<span class="sd"> :type element: cutlass.DataType</span>
|
||||
<span class="sd"> :param layout: generic layout type to be used for operands A, B, C, and D</span>
|
||||
<span class="sd"> :type layout: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param element_A: data type to be used for operand A</span>
|
||||
<span class="sd"> :type element_A: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_B: data type to be used for operand B</span>
|
||||
<span class="sd"> :type element_B: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_C: data type to be used for operand C</span>
|
||||
<span class="sd"> :type element_C: cutlass.DataType</span>
|
||||
<span class="sd"> :param element_D: data type to be used for operand D</span>
|
||||
<span class="sd"> :type element_D: cutlass.DataType</span>
|
||||
<span class="sd"> :param layout_A: layout of operand A</span>
|
||||
<span class="sd"> :type layout_A: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param layout_B: layout of operand B</span>
|
||||
<span class="sd"> :type layout_B: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param layout_C: layout of operand C</span>
|
||||
<span class="sd"> :type layout_C: cutlass.LayoutType</span>
|
||||
<span class="sd"> :param layout_D: layout of operand D</span>
|
||||
<span class="sd"> :type layout_D: cutlass.LayoutType</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alpha</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">element</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">element_A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">layout_A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout_B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout_C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">cc</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
|
||||
<span class="n">A</span><span class="o">=</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">C</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="n">D</span><span class="p">,</span>
|
||||
<span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="n">beta</span><span class="p">,</span>
|
||||
<span class="n">element_accumulator</span><span class="o">=</span><span class="n">element_accumulator</span><span class="p">,</span>
|
||||
<span class="n">element</span><span class="o">=</span><span class="n">element</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">layout</span><span class="p">,</span>
|
||||
<span class="n">element_A</span><span class="o">=</span><span class="n">element_A</span><span class="p">,</span> <span class="n">element_B</span><span class="o">=</span><span class="n">element_B</span><span class="p">,</span>
|
||||
<span class="n">element_C</span><span class="o">=</span><span class="n">element_C</span><span class="p">,</span> <span class="n">element_D</span><span class="o">=</span><span class="n">element_D</span><span class="p">,</span>
|
||||
<span class="n">layout_A</span><span class="o">=</span><span class="n">layout_A</span><span class="p">,</span> <span class="n">layout_B</span><span class="o">=</span><span class="n">layout_B</span><span class="p">,</span> <span class="n">layout_C</span><span class="o">=</span><span class="n">layout_C</span><span class="p">,</span>
|
||||
<span class="n">cc</span><span class="o">=</span><span class="n">cc</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="c1"># Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">==</span> <span class="mi">90</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_options</span><span class="p">(</span><span class="mi">80</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_operations</span><span class="p">(</span><span class="n">reset_epilogue</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="s2">"grouped_gemm"</span>
|
||||
|
||||
<span class="nd">@Gemm</span><span class="o">.</span><span class="n">swizzling_functor</span><span class="o">.</span><span class="n">setter</span>
|
||||
<span class="k">def</span> <span class="nf">swizzling_functor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">swizzling_functor</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Sets the swizzling functor to the type specified by `swizzling_functor`</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">'Grouped GEMM does not currently support different swizzling functors'</span><span class="p">)</span>
|
||||
|
||||
<div class="viewcode-block" id="GroupedGemm.construct"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm_grouped.GroupedGemm.construct">[docs]</a> <span class="k">def</span> <span class="nf">construct</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tile_description</span><span class="p">:</span> <span class="n">TileDescription</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alignment_A</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alignment_B</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">alignment_C</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">GemmOperationGrouped</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current</span>
|
||||
<span class="sd"> kernel specification of the ``Gemm`` object.</span>
|
||||
|
||||
<span class="sd"> :param tile_description: tile description specifying shapes and operand types to use in the kernel</span>
|
||||
<span class="sd"> :type tile_description: cutlass.backend.TileDescription</span>
|
||||
<span class="sd"> :param alignment_A: alignment of operand A</span>
|
||||
<span class="sd"> :type alignment_A: int</span>
|
||||
<span class="sd"> :param alignment_B: alignment of operand B</span>
|
||||
<span class="sd"> :type alignment_B: int</span>
|
||||
<span class="sd"> :param alignment_C: alignment of operand C</span>
|
||||
<span class="sd"> :type alignment_C: int</span>
|
||||
|
||||
<span class="sd"> :return: operation that was constructed</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.GemmOperationGrouped</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">alignment_preference</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">alignments</span><span class="p">)</span>
|
||||
<span class="n">alignment_A</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_A</span><span class="p">,</span> <span class="n">alignment_preference</span><span class="p">)</span>
|
||||
<span class="n">alignment_B</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_B</span><span class="p">,</span> <span class="n">alignment_preference</span><span class="p">)</span>
|
||||
<span class="n">alignment_C</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_C</span><span class="p">,</span> <span class="n">alignment_preference</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_epilogue_functor_alignment</span><span class="p">(</span><span class="n">alignment_C</span><span class="p">)</span>
|
||||
|
||||
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">),</span>
|
||||
<span class="n">alignment_A</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">),</span>
|
||||
<span class="n">alignment_B</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">),</span>
|
||||
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">),</span>
|
||||
<span class="n">alignment_C</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">tile_description</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">op</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">operations</span><span class="p">(</span><span class="n">alignment_A</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">tile_description</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">td_from_profiler_op</span><span class="p">(</span><span class="n">op</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">valid</span><span class="p">,</span> <span class="n">err_str</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_valid_tile_description</span><span class="p">(</span><span class="n">tile_description</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid tile description. </span><span class="si">{</span><span class="n">err_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span> <span class="o">=</span> <span class="n">tile_description</span>
|
||||
|
||||
<span class="n">operation</span> <span class="o">=</span> <span class="n">GemmOperationGrouped</span><span class="p">(</span>
|
||||
<span class="n">arch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">,</span>
|
||||
<span class="n">tile_description</span><span class="o">=</span><span class="n">tile_description</span><span class="p">,</span>
|
||||
<span class="n">A</span><span class="o">=</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">tensor_B</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">tensor_C</span><span class="p">,</span>
|
||||
<span class="n">epilogue_functor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="p">,</span>
|
||||
<span class="n">swizzling_functor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_swizzling_functor</span><span class="p">,</span>
|
||||
<span class="n">precompute_mode</span><span class="o">=</span><span class="n">SchedulerMode</span><span class="o">.</span><span class="n">Device</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">operation</span></div>
|
||||
|
||||
<div class="viewcode-block" id="GroupedGemm.run"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm_grouped.GroupedGemm.run">[docs]</a> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">D</span><span class="p">,</span>
|
||||
<span class="n">alpha</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sync</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">print_module</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">GemmGroupedArguments</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Runs the kernel currently specified.</span>
|
||||
|
||||
<span class="sd"> By default, this call returns only once the kernel has completed. To launch the kernel</span>
|
||||
<span class="sd"> and immediately return, set ``sync=False``. In this case, it is the responsibility of the</span>
|
||||
<span class="sd"> caller to synchronize the results of the kernel before attempting to access outputs</span>
|
||||
<span class="sd"> by calling ``sync()`` on the arguments returned from this call.</span>
|
||||
|
||||
<span class="sd"> :param A: list of tensors representing data type and layout of operand A</span>
|
||||
<span class="sd"> :type A: list</span>
|
||||
<span class="sd"> :param B: list of tensors representing data type and layout of operand B</span>
|
||||
<span class="sd"> :type B: list</span>
|
||||
<span class="sd"> :param C: list of tensors representing data type and layout of operand C</span>
|
||||
<span class="sd"> :type C: list</span>
|
||||
<span class="sd"> :param D: list of tensors representing data type and layout of operand D</span>
|
||||
<span class="sd"> :type D: list</span>
|
||||
<span class="sd"> :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B</span>
|
||||
<span class="sd"> :param beta: scalar parameter beta from GEMM operation that scales operand C</span>
|
||||
<span class="sd"> :param sync: whether the call should wait for the kernel to complete before returning</span>
|
||||
<span class="sd"> :type sync: bool</span>
|
||||
<span class="sd"> :param print_module: whether to print the emitted C++ code</span>
|
||||
<span class="sd"> :type print_module: bool</span>
|
||||
|
||||
<span class="sd"> :return: arguments passed in to the kernel</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.GemmGroupedArguments</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">B</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">C</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">D</span><span class="p">):</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"Lengths of A, B, C, and D lists must be equal"</span><span class="p">)</span>
|
||||
|
||||
<span class="n">problem_sizes</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span> <span class="o">=</span> <span class="p">([</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">))</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)):</span>
|
||||
<span class="n">As</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">A</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="s2">"A"</span><span class="p">)</span>
|
||||
<span class="n">Bs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">B</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">,</span> <span class="s2">"B"</span><span class="p">)</span>
|
||||
<span class="n">Cs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">C</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">,</span> <span class="s2">"C"</span><span class="p">)</span>
|
||||
<span class="n">Ds</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">D</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_d</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_d</span><span class="p">,</span> <span class="s2">"D"</span><span class="p">)</span>
|
||||
<span class="n">problem_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">gemm</span><span class="o">.</span><span class="n">GemmCoord</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
|
||||
|
||||
<span class="n">alpha</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_scalar</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="s2">"alpha"</span><span class="p">)</span>
|
||||
<span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_scalar</span><span class="p">(</span><span class="n">beta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="s2">"beta"</span><span class="p">)</span>
|
||||
|
||||
<span class="n">alignment_a</span> <span class="o">=</span> <span class="nb">min</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">)</span> <span class="k">for</span> <span class="n">A</span> <span class="ow">in</span> <span class="n">As</span><span class="p">))</span>
|
||||
<span class="n">alignment_b</span> <span class="o">=</span> <span class="nb">min</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">)</span> <span class="k">for</span> <span class="n">B</span> <span class="ow">in</span> <span class="n">Bs</span><span class="p">))</span>
|
||||
<span class="n">alignment_c</span> <span class="o">=</span> <span class="nb">min</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">)</span> <span class="k">for</span> <span class="n">C</span> <span class="ow">in</span> <span class="n">Cs</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">alignment_A</span><span class="o">=</span><span class="n">alignment_a</span><span class="p">,</span> <span class="n">alignment_B</span><span class="o">=</span><span class="n">alignment_b</span><span class="p">,</span>
|
||||
<span class="n">alignment_C</span><span class="o">=</span><span class="n">alignment_c</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
|
||||
|
||||
<span class="n">arguments</span> <span class="o">=</span> <span class="n">GemmGroupedArguments</span><span class="p">(</span>
|
||||
<span class="n">operation</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="p">,</span>
|
||||
<span class="n">problem_sizes</span><span class="o">=</span><span class="n">problem_sizes</span><span class="p">,</span>
|
||||
<span class="n">A</span><span class="o">=</span><span class="n">As</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">Bs</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">Cs</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="n">Ds</span><span class="p">,</span>
|
||||
<span class="n">output_op</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">epilogue_type</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">arguments</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">sync</span><span class="p">:</span>
|
||||
<span class="n">arguments</span><span class="o">.</span><span class="n">sync</span><span class="p">()</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">arguments</span></div></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
||||
<script src="../../../_static/doctools.js"></script>
|
||||
<script src="../../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../../_static/scripts/furo.js"></script>
|
||||
<script src="../../../_static/clipboard.min.js"></script>
|
||||
<script src="../../../_static/copybutton.js"></script>
|
||||
<script src="../../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
400
python/docs/_modules/cutlass/op/op.html
Normal file
400
python/docs/_modules/cutlass/op/op.html
Normal file
@ -0,0 +1,400 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/op/op.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.op.op - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.op.op</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">bisect</span> <span class="kn">import</span> <span class="n">bisect_left</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">cutlass</span> <span class="kn">import</span> <span class="n">option_registry</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.utils.device</span> <span class="kn">import</span> <span class="n">device_cc</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.epilogue</span> <span class="kn">import</span> <span class="n">get_activations</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.library_defaults</span> <span class="kn">import</span> <span class="n">_generator_ccs</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.swizzle</span> <span class="kn">import</span> <span class="n">get_swizzling_functors</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="OperationBase"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.op.OperationBase">[docs]</a><span class="k">class</span> <span class="nc">OperationBase</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">kernel_cc</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80</span>
|
||||
<span class="sd"> :type kernel_cc: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">cc</span> <span class="o">=</span> <span class="n">cc</span> <span class="k">if</span> <span class="n">cc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">device_cc</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">specified_kernel_cc</span> <span class="o">=</span> <span class="n">kernel_cc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">=</span> <span class="n">kernel_cc</span> <span class="k">if</span> <span class="n">kernel_cc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_closest_cc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cc</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">options</span> <span class="o">=</span> <span class="n">option_registry</span><span class="o">.</span><span class="n">options_for_cc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">options</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid or unsupported compute capability: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_find_closest_cc</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the closest CC in _generator_ccs less than or equal to `cc`</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability to query</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
|
||||
<span class="sd"> :returns: closest CC in _generator_ccs less than or equal to `cc`</span>
|
||||
<span class="sd"> :rtype: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="ow">in</span> <span class="n">_generator_ccs</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cc</span>
|
||||
|
||||
<span class="c1"># Find closest CC lower than this CC</span>
|
||||
<span class="n">idx</span> <span class="o">=</span> <span class="n">bisect_left</span><span class="p">(</span><span class="n">_generator_ccs</span><span class="p">,</span> <span class="n">cc</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">idx</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'No valid CC to fall back to for </span><span class="si">{</span><span class="n">cc</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_generator_ccs</span><span class="p">[</span><span class="n">idx</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
|
||||
<div class="viewcode-block" id="OperationBase.activations"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.op.OperationBase.activations">[docs]</a> <span class="k">def</span> <span class="nf">activations</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns possible activation functions that can be used</span>
|
||||
|
||||
<span class="sd"> :return: list of activation functions that can be used</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="n">get_activations</span><span class="p">()</span></div>
|
||||
|
||||
<div class="viewcode-block" id="OperationBase.swizzling_functors"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.op.OperationBase.swizzling_functors">[docs]</a> <span class="k">def</span> <span class="nf">swizzling_functors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns possible swizzling functions that can be used</span>
|
||||
|
||||
<span class="sd"> :return: list of swizzling functions that can be used</span>
|
||||
<span class="sd"> :rtype: list</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="n">get_swizzling_functors</span><span class="p">()</span></div>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_reset_options</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Resets the kernel options based on cc</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability to reset to</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_generator_ccs</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Invalid CC for CUTLASS kernels: </span><span class="si">{</span><span class="n">cc</span><span class="si">}</span><span class="s1">.'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">=</span> <span class="n">cc</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">options</span> <span class="o">=</span> <span class="n">option_registry</span><span class="o">.</span><span class="n">options_for_cc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">)</span></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
||||
<script src="../../../_static/doctools.js"></script>
|
||||
<script src="../../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../../_static/scripts/furo.js"></script>
|
||||
<script src="../../../_static/clipboard.min.js"></script>
|
||||
<script src="../../../_static/copybutton.js"></script>
|
||||
<script src="../../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
350
python/docs/_modules/cutlass/swizzle.html
Normal file
350
python/docs/_modules/cutlass/swizzle.html
Normal file
@ -0,0 +1,350 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../genindex.html" /><link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/swizzle.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.swizzle - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../_static/cutlass-logo-small.png" alt="CUTLASS logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../_static/cutlass-logo-small.png" alt="CUTLASS logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.swizzle</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Registry of swizzling functions</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
||||
|
||||
<span class="n">IdentitySwizzle1</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">IdentitySwizzle1</span>
|
||||
<span class="n">IdentitySwizzle2</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">IdentitySwizzle2</span>
|
||||
<span class="n">IdentitySwizzle4</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">IdentitySwizzle4</span>
|
||||
<span class="n">IdentitySwizzle8</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">IdentitySwizzle8</span>
|
||||
<span class="n">HorizontalSwizzle</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">HorizontalSwizzle</span>
|
||||
<span class="n">BatchedIdentitySwizzle</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">BatchedIdentitySwizzle</span>
|
||||
<span class="n">ThreadblockSwizzleStreamK</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">ThreadblockSwizzleStreamK</span>
|
||||
<span class="n">StridedDgradIdentitySwizzle1</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">StridedDgradIdentitySwizzle1</span>
|
||||
<span class="n">StridedDgradIdentitySwizzle4</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">StridedDgradIdentitySwizzle4</span>
|
||||
<span class="n">StridedDgradHorizontalSwizzle</span> <span class="o">=</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">StridedDgradHorizontalSwizzle</span>
|
||||
|
||||
|
||||
<span class="n">_swizzling_functors</span> <span class="o">=</span> <span class="p">[</span>
|
||||
<span class="n">IdentitySwizzle1</span><span class="p">,</span>
|
||||
<span class="n">IdentitySwizzle2</span><span class="p">,</span>
|
||||
<span class="n">IdentitySwizzle4</span><span class="p">,</span>
|
||||
<span class="n">IdentitySwizzle8</span><span class="p">,</span>
|
||||
<span class="n">HorizontalSwizzle</span><span class="p">,</span>
|
||||
<span class="n">BatchedIdentitySwizzle</span><span class="p">,</span>
|
||||
<span class="n">ThreadblockSwizzleStreamK</span><span class="p">,</span>
|
||||
<span class="n">StridedDgradIdentitySwizzle1</span><span class="p">,</span>
|
||||
<span class="n">StridedDgradIdentitySwizzle4</span><span class="p">,</span>
|
||||
<span class="n">StridedDgradHorizontalSwizzle</span><span class="p">,</span>
|
||||
<span class="p">]</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="get_swizzling_functors"><a class="viewcode-back" href="../../cutlass.html#cutlass.swizzle.get_swizzling_functors">[docs]</a><span class="k">def</span> <span class="nf">get_swizzling_functors</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="n">_swizzling_functors</span></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script src="../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../_static/scripts/furo.js"></script>
|
||||
<script src="../../_static/clipboard.min.js"></script>
|
||||
<script src="../../_static/copybutton.js"></script>
|
||||
<script src="../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
476
python/docs/_modules/cutlass/utils/check.html
Normal file
476
python/docs/_modules/cutlass/utils/check.html
Normal file
@ -0,0 +1,476 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/utils/check.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.utils.check - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.utils.check</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Utility functions for checking constraints on kernels and calculating kernel attributes</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">ctypes</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
||||
<span class="kn">import</span> <span class="nn">cutlass</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="n">DataTypeSize</span><span class="p">,</span> <span class="n">TileDescription</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="calculate_smem_usage_per_stage"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.check.calculate_smem_usage_per_stage">[docs]</a><span class="k">def</span> <span class="nf">calculate_smem_usage_per_stage</span><span class="p">(</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">operation_kind</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the amount of shared memory in bytes consumed in a single stage of a kernel.</span>
|
||||
|
||||
<span class="sd"> :return: number of bytes of shared memory consumed by a single stage</span>
|
||||
<span class="sd"> :rtype: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span> <span class="o">=</span> <span class="n">tile_description</span><span class="o">.</span><span class="n">threadblock_shape</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">operation_kind</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OperationKind</span><span class="o">.</span><span class="n">Gemm</span><span class="p">:</span>
|
||||
<span class="n">stage_barrier_bytes</span> <span class="o">=</span> <span class="mi">32</span>
|
||||
<span class="k">return</span> <span class="p">(</span>
|
||||
<span class="p">(</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="n">tile_description</span><span class="o">.</span><span class="n">math_instruction</span><span class="o">.</span><span class="n">element_a</span><span class="p">]</span> <span class="o">*</span> <span class="n">m</span> <span class="o">*</span> <span class="n">k</span> <span class="o">//</span> <span class="mi">8</span><span class="p">)</span>
|
||||
<span class="o">+</span> <span class="p">(</span><span class="n">DataTypeSize</span><span class="p">[</span><span class="n">tile_description</span><span class="o">.</span><span class="n">math_instruction</span><span class="o">.</span><span class="n">element_b</span><span class="p">]</span> <span class="o">*</span> <span class="n">k</span> <span class="o">*</span> <span class="n">n</span> <span class="o">//</span> <span class="mi">8</span><span class="p">)</span>
|
||||
<span class="o">+</span> <span class="n">stage_barrier_bytes</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No available shared memory calculation for operation kind </span><span class="si">{</span><span class="n">operation</span><span class="o">.</span><span class="n">operation_kind</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="calculate_smem_usage"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.check.calculate_smem_usage">[docs]</a><span class="k">def</span> <span class="nf">calculate_smem_usage</span><span class="p">(</span><span class="n">operation</span><span class="p">):</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns the amount of shared memory in bytes consumed by a kernel.</span>
|
||||
|
||||
<span class="sd"> :return: number of bytes of shared memory consumed by the operation</span>
|
||||
<span class="sd">:rtype: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">_per_stage</span> <span class="o">=</span> <span class="n">calculate_smem_usage_per_stage</span><span class="p">(</span><span class="n">operation</span><span class="o">.</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">operation</span><span class="o">.</span><span class="n">operation_kind</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_per_stage</span> <span class="o">*</span> <span class="n">operation</span><span class="o">.</span><span class="n">tile_description</span><span class="o">.</span><span class="n">stages</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="valid_stage_count"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.check.valid_stage_count">[docs]</a><span class="k">def</span> <span class="nf">valid_stage_count</span><span class="p">(</span><span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">td</span><span class="p">:</span> <span class="n">TileDescription</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Checks whether a device with `cc` supports the number of stages within `tile_description`, both</span>
|
||||
<span class="sd"> based on raw limits on the number of stages and based on shared memory capacity</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability of device in question</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param td: tile description to check</span>
|
||||
<span class="sd"> :type td: TileDescription</span>
|
||||
|
||||
<span class="sd"> :return: tuple with the first element indicating whether the provided tile description is</span>
|
||||
<span class="sd"> valid for the provided device and the second element being an error message</span>
|
||||
<span class="sd"> :rtype: tuple</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="o">==</span> <span class="mi">90</span> <span class="ow">and</span> <span class="p">(</span><span class="n">td</span><span class="o">.</span><span class="n">stages</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">td</span><span class="o">.</span><span class="n">stages</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span>
|
||||
<span class="c1"># Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically</span>
|
||||
<span class="c1"># determines the stage count to use. Thus, all settings are valid in these scenarios.</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">True</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">td</span><span class="o">.</span><span class="n">stages</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Stage counts must be positive integers. Tile description has stage count of </span><span class="si">{</span><span class="n">td</span><span class="o">.</span><span class="n">stages</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="o">&lt;</span> <span class="mi">80</span> <span class="ow">and</span> <span class="n">td</span><span class="o">.</span><span class="n">stages</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Tile description has stage count of </span><span class="si">{</span><span class="n">td</span><span class="o">.</span><span class="n">stages</span><span class="si">}</span><span class="s2">, "</span>
|
||||
<span class="sa">f</span><span class="s2">"but only 2 stages are supported on SM</span><span class="si">{</span><span class="n">cc</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||||
|
||||
<span class="n">smem_per_stage</span> <span class="o">=</span> <span class="n">calculate_smem_usage_per_stage</span><span class="p">(</span><span class="n">td</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OperationKind</span><span class="o">.</span><span class="n">Gemm</span><span class="p">)</span>
|
||||
<span class="n">smem_arch</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">SharedMemPerCC</span><span class="p">[</span><span class="n">cc</span><span class="p">]</span> <span class="o">&lt;&lt;</span> <span class="mi">10</span>
|
||||
<span class="k">if</span> <span class="p">(</span><span class="n">smem_per_stage</span> <span class="o">*</span> <span class="n">td</span><span class="o">.</span><span class="n">stages</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">smem_arch</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span> <span class="kc">False</span><span class="p">,</span>
|
||||
<span class="s2">"Configuration uses too much shared memory. Consider reducing stage count or tile shape.</span><span class="se">\n</span><span class="s2">"</span>
|
||||
<span class="sa">f</span><span class="s2">"Details: configuration uses </span><span class="si">{</span><span class="n">smem_per_stage</span><span class="si">}</span><span class="s2"> bytes of shared memory per stage, and "</span>
|
||||
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">td</span><span class="o">.</span><span class="n">stages</span><span class="si">}</span><span class="s2"> stages for a total of </span><span class="si">{</span><span class="n">smem_per_stage</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">td</span><span class="o">.</span><span class="n">stages</span><span class="si">}</span><span class="s2"> bytes.</span><span class="se">\n</span><span class="s2">"</span>
|
||||
<span class="sa">f</span><span class="s2">"The maxmium amoung of shared memory that can be used per block on CC </span><span class="si">{</span><span class="n">cc</span><span class="si">}</span><span class="s2"> is </span><span class="si">{</span><span class="n">smem_arch</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">True</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="valid_cluster_shape"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.check.valid_cluster_shape">[docs]</a><span class="k">def</span> <span class="nf">valid_cluster_shape</span><span class="p">(</span><span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">cluster_shape</span><span class="p">:</span> <span class="nb">list</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability of device in question</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param cluster_shape: dimensions of thread block cluster shape to check</span>
|
||||
<span class="sd"> :type cluster_shape: list</span>
|
||||
|
||||
<span class="sd"> :return: tuple with the first element indicating whether the provided cluster shape is</span>
|
||||
<span class="sd"> valid for the provided device and the second element being an error message</span>
|
||||
<span class="sd"> :rtype: tuple</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">cc</span> <span class="o">&lt;</span> <span class="mi">90</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">cluster_shape</span> <span class="o">!=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="sa">f</span><span class="s2">"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of "</span>
|
||||
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">cluster_shape</span><span class="si">}</span><span class="s2"> for SM</span><span class="si">{</span><span class="n">cc</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">True</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cluster_shape</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="sa">f</span><span class="s2">"Cluster shapes must be rank-3. Received </span><span class="si">{</span><span class="n">cluster_shape</span><span class="si">}</span><span class="s2"> (rank </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">cluster_shape</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">cluster_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="s2">"CUTLASS kernels currently require the third dimension of cluster shape to be 1. "</span>
|
||||
<span class="sa">f</span><span class="s2">"Received cluster shape of </span><span class="si">{</span><span class="n">cluster_shape</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster</span>
|
||||
<span class="c1"># as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters).</span>
|
||||
<span class="c1"># Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions,</span>
|
||||
<span class="c1"># so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total.</span>
|
||||
<span class="n">blocks_in_2d</span> <span class="o">=</span> <span class="n">cluster_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">cluster_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="k">if</span> <span class="n">blocks_in_2d</span> <span class="o">&gt;</span> <span class="mi">8</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="sa">f</span><span class="s2">"Thread block clusters with more than 8 thread blocks are currently unsupported on SM</span><span class="si">{</span><span class="n">cc</span><span class="si">}</span><span class="s2">. "</span>
|
||||
<span class="sa">f</span><span class="s2">"Received cluster shape </span><span class="si">{</span><span class="n">cluster_shape</span><span class="si">}</span><span class="s2">, which has </span><span class="si">{</span><span class="n">blocks_in_2d</span><span class="si">}</span><span class="s2"> thread blocks."</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">True</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="valid_kernel_schedule"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.check.valid_kernel_schedule">[docs]</a><span class="k">def</span> <span class="nf">valid_kernel_schedule</span><span class="p">(</span><span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kernel_schedule</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">KernelScheduleType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Checks whether a device with ``cc`` supports ``kernel_schedule``.</span>
|
||||
|
||||
<span class="sd"> :param cc: compute capability of device in question</span>
|
||||
<span class="sd"> :type cc: int</span>
|
||||
<span class="sd"> :param kernel_schedule: kernel schedule type</span>
|
||||
<span class="sd">:type kernel_schedule: cutlass.KernelScheduleType</span>
|
||||
|
||||
<span class="sd"> :return: tuple with the first element indicating whether the provided kernel schedule is</span>
|
||||
<span class="sd"> valid for the provided device and the second element being an error message</span>
|
||||
<span class="sd"> :rtype: tuple</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">kernel_schedule</span> <span class="o">!=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">KernelScheduleType</span><span class="o">.</span><span class="n">ScheduleAuto</span> <span class="ow">and</span> <span class="n">cc</span> <span class="o">&lt;</span> <span class="mi">90</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="s2">"Non-default kernel schedules are only supported on SM90 and beyond"</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="kc">True</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="alignment_or_default"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.check.alignment_or_default">[docs]</a><span class="k">def</span> <span class="nf">alignment_or_default</span><span class="p">(</span><span class="n">alignment_provided</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">default_alignment</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks</span>
|
||||
<span class="sd"> that `alignment_provided` does not exceed `default_alignment`.</span>
|
||||
|
||||
<span class="sd"> :param alignment_provided: alignment preference specified. Can be None.</span>
|
||||
<span class="sd"> :type alignment_provided: int</span>
|
||||
<span class="sd"> :param default_alignment: alignment to use if `alignment_provided` is None</span>
|
||||
<span class="sd"> :type default_alignment: int</span>
|
||||
|
||||
<span class="sd"> :return: alignment to use</span>
|
||||
<span class="sd"> :rtype: int</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">alignment_provided</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">alignment_provided</span> <span class="o">></span> <span class="n">default_alignment</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Alignment </span><span class="si">{</span><span class="n">alignment_provided</span><span class="si">}</span><span class="s2"> exceeds the maximum supported of </span><span class="si">{</span><span class="n">default_alignment</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">alignment_provided</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">default_alignment</span></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
||||
<script src="../../../_static/doctools.js"></script>
|
||||
<script src="../../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../../_static/scripts/furo.js"></script>
|
||||
<script src="../../../_static/clipboard.min.js"></script>
|
||||
<script src="../../../_static/copybutton.js"></script>
|
||||
<script src="../../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
623
python/docs/_modules/cutlass/utils/datatypes.html
Normal file
623
python/docs/_modules/cutlass/utils/datatypes.html
Normal file
@ -0,0 +1,623 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="canonical" href="docs/_modules/cutlass/utils/datatypes.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>cutlass.utils.datatypes - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||||
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>Source code for cutlass.utils.datatypes</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#################################################################################################</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
|
||||
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
||||
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
||||
<span class="c1"># list of conditions and the following disclaimer.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
||||
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
||||
<span class="c1"># and/or other materials provided with the distribution.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
||||
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
||||
<span class="c1"># this software without specific prior written permission.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
||||
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
||||
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
||||
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
||||
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
||||
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
||||
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
||||
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
||||
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
||||
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1">#################################################################################################</span>
|
||||
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd">Utility functions for converting between frontend datatypes and CUTLASS datatypes</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">cutlass</span>
|
||||
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="p">(</span>
|
||||
<span class="n">DataTypeSize</span><span class="p">,</span>
|
||||
<span class="n">MathInstruction</span><span class="p">,</span>
|
||||
<span class="n">MathOperation</span><span class="p">,</span>
|
||||
<span class="n">ShortLayoutTypeNames</span><span class="p">,</span>
|
||||
<span class="n">TileDescription</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="n">numpy_available</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="n">_library_to_numpy_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s32</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
|
||||
<span class="n">numpy_available</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="n">_library_to_numpy_dict</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="numpy_library_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.numpy_library_type">[docs]</a><span class="k">def</span> <span class="nf">numpy_library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="o">-></span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">numpy_available</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span>
|
||||
<span class="k">elif</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span>
|
||||
<span class="k">elif</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span>
|
||||
<span class="k">elif</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span>
|
||||
<span class="k">elif</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s32</span>
|
||||
<span class="k">return</span> <span class="kc">None</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="numpy_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.numpy_type">[docs]</a><span class="k">def</span> <span class="nf">numpy_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_library_to_numpy_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="kn">import</span> <span class="nn">cupy</span> <span class="k">as</span> <span class="nn">cp</span>
|
||||
|
||||
<span class="n">cupy_available</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="n">_library_to_cupy_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">:</span> <span class="n">cp</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">:</span> <span class="n">cp</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">:</span> <span class="n">cp</span><span class="o">.</span><span class="n">float64</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">:</span> <span class="n">cp</span><span class="o">.</span><span class="n">int8</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s32</span><span class="p">:</span> <span class="n">cp</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
|
||||
<span class="n">cupy_available</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="n">_library_to_cupy_dict</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="cupy_library_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.cupy_library_type">[docs]</a><span class="k">def</span> <span class="nf">cupy_library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="o">-></span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">cupy_available</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">cp</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span>
|
||||
<span class="k">elif</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">cp</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span>
|
||||
<span class="k">elif</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">cp</span><span class="o">.</span><span class="n">float64</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span>
|
||||
<span class="k">return</span> <span class="kc">None</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="cupy_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.cupy_type">[docs]</a><span class="k">def</span> <span class="nf">cupy_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_library_to_cupy_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
<span class="n">torch_available</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="n">_torch_to_library_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">half</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">,</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">float64</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
<span class="n">_library_to_torch_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">half</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float64</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
|
||||
<span class="n">torch_available</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="n">_torch_to_library_dict</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="n">_library_to_torch_dict</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="torch_library_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.torch_library_type">[docs]</a><span class="k">def</span> <span class="nf">torch_library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="o">-></span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">_torch_to_library_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="torch_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.torch_type">[docs]</a><span class="k">def</span> <span class="nf">torch_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_library_to_torch_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="kn">import</span> <span class="nn">bfloat16</span>
|
||||
|
||||
<span class="n">bfloat16_available</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
|
||||
<span class="n">bfloat16_available</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="bfloat16_library_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.bfloat16_library_type">[docs]</a><span class="k">def</span> <span class="nf">bfloat16_library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="o">-></span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">bfloat16_available</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">bfloat16</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">bf16</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="bfloat16_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.bfloat16_type">[docs]</a><span class="k">def</span> <span class="nf">bfloat16_type</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="o">-></span> <span class="n">bfloat16</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">bfloat16_available</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">bf16</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">bfloat16</span><span class="o">.</span><span class="n">bfloat16</span></div>
|
||||
|
||||
|
||||
<span class="c1"># Mapping from library data type to Python-bound CUTLASS data type</span>
|
||||
<span class="n">library_to_binding_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int8</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s32</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">bf16</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float64</span><span class="p">,</span>
|
||||
<span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">tf32</span><span class="p">:</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">tfloat32</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
<span class="c1"># Mapping from Python-bound CUTLASS data type to library data type</span>
|
||||
<span class="n">binding_to_library</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int8</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s8</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">s32</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f16</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">bf16</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f32</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float64</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">f64</span><span class="p">,</span>
|
||||
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">tfloat32</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">tf32</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="binding_library_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.binding_library_type">[docs]</a><span class="k">def</span> <span class="nf">binding_library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">binding_to_library</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">binding_to_library</span><span class="p">[</span><span class="n">inp</span><span class="p">]</span>
|
||||
<span class="k">return</span> <span class="kc">None</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="has_binding_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.has_binding_type">[docs]</a><span class="k">def</span> <span class="nf">has_binding_type</span><span class="p">(</span><span class="n">inp</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">library_to_binding_dict</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="library_to_binding"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.library_to_binding">[docs]</a><span class="k">def</span> <span class="nf">library_to_binding</span><span class="p">(</span><span class="n">inp</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">has_binding_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No available conversion from library type </span><span class="si">{</span><span class="n">inp</span><span class="si">}</span><span class="s2"> to Python-bound CUTLASS type"</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">library_to_binding_dict</span><span class="p">[</span><span class="n">inp</span><span class="p">]</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="library_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.library_type">[docs]</a><span class="k">def</span> <span class="nf">library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">DataTypeSize</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="n">inp</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">cvt_fn</span> <span class="ow">in</span> <span class="p">[</span>
|
||||
<span class="n">bfloat16_library_type</span><span class="p">,</span>
|
||||
<span class="n">cupy_library_type</span><span class="p">,</span>
|
||||
<span class="n">numpy_library_type</span><span class="p">,</span>
|
||||
<span class="n">torch_library_type</span><span class="p">,</span>
|
||||
<span class="n">binding_library_type</span><span class="p">,</span>
|
||||
<span class="p">]:</span>
|
||||
<span class="n">out</span> <span class="o">=</span> <span class="n">cvt_fn</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">out</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">out</span>
|
||||
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No available conversion from type </span><span class="si">{</span><span class="n">inp</span><span class="si">}</span><span class="s2"> to a library type."</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="library_layout"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.library_layout">[docs]</a><span class="k">def</span> <span class="nf">library_layout</span><span class="p">(</span><span class="n">layout</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">layout</span> <span class="ow">in</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutTag</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="n">layout</span>
|
||||
|
||||
<span class="c1"># Convert Python-bound CUTLASS layout to profiler library layout</span>
|
||||
<span class="k">if</span> <span class="n">layout</span> <span class="o">==</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span>
|
||||
<span class="k">elif</span> <span class="n">layout</span> <span class="o">==</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No conversion available for layout </span><span class="si">{</span><span class="n">layout</span><span class="si">}</span><span class="s2"> to library layout."</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="binding_type"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.binding_type">[docs]</a><span class="k">def</span> <span class="nf">binding_type</span><span class="p">(</span><span class="n">inp</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">DataTypeSize</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="n">inp</span>
|
||||
|
||||
<span class="n">libtype</span> <span class="o">=</span> <span class="n">library_type</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">library_to_binding</span><span class="p">(</span><span class="n">libtype</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="binding_layout"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.binding_layout">[docs]</a><span class="k">def</span> <span class="nf">binding_layout</span><span class="p">(</span><span class="n">layout</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">layout</span> <span class="ow">in</span> <span class="n">ShortLayoutTypeNames</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="n">layout</span>
|
||||
<span class="k">elif</span> <span class="n">layout</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">RowMajor</span>
|
||||
<span class="k">elif</span> <span class="n">layout</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">ColumnMajor</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"No conversion available for layout </span><span class="si">{</span><span class="n">layout</span><span class="si">}</span><span class="s2"> to Python-bound CUTLASS layout."</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">_tensor_from_numpy</span><span class="p">(</span><span class="n">np_tensor</span><span class="p">):</span>
|
||||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">library_type</span><span class="p">(</span><span class="n">np_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">np_tensor</span><span class="o">.</span><span class="n">flags</span><span class="o">.</span><span class="n">c_contiguous</span><span class="p">:</span>
|
||||
<span class="n">layout</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span>
|
||||
<span class="k">elif</span> <span class="n">np_tensor</span><span class="o">.</span><span class="n">flags</span><span class="o">.</span><span class="n">f_contiguous</span><span class="p">:</span>
|
||||
<span class="n">layout</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">ColumnMajor</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">layout</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">_tensor_from_torch</span><span class="p">(</span><span class="n">pt_tensor</span><span class="p">):</span>
|
||||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">library_type</span><span class="p">(</span><span class="n">pt_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">)</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="get_datatype_and_layout"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.get_datatype_and_layout">[docs]</a><span class="k">def</span> <span class="nf">get_datatype_and_layout</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="p">(</span><span class="n">numpy_available</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">))</span> <span class="ow">or</span> <span class="p">(</span>
|
||||
<span class="n">cupy_available</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">cp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span>
|
||||
<span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_tensor_from_numpy</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="n">torch_available</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">_tensor_from_torch</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Unable to convert tensor of type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span><span class="si">}</span><span class="s2"> to Python-bound CUTLASS datatype and layout."</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="binding_opclass"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.binding_opclass">[docs]</a><span class="k">def</span> <span class="nf">binding_opclass</span><span class="p">(</span><span class="n">opclass</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">opclass</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">TensorOp</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">OpClass</span><span class="o">.</span><span class="n">TensorOp</span>
|
||||
<span class="k">elif</span> <span class="n">opclass</span> <span class="o">==</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">OpClass</span><span class="o">.</span><span class="n">Simt</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Unable to convert opcode class of type </span><span class="si">{</span><span class="n">opclass</span><span class="si">}</span><span class="s2"> to Python-bound CUTLASS opcode class."</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="n">_math_operation_value_map</span> <span class="o">=</span> <span class="p">{</span><span class="n">x</span><span class="o">.</span><span class="n">value</span><span class="p">:</span> <span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">MathOperation</span><span class="p">}</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="backend_math_operation"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.backend_math_operation">[docs]</a><span class="k">def</span> <span class="nf">backend_math_operation</span><span class="p">(</span><span class="n">math_op</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">MathOperation</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">math_op</span><span class="o">.</span><span class="n">value</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_math_operation_value_map</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Unable to convert math operation of type </span><span class="si">{</span><span class="n">math_op</span><span class="si">}</span><span class="s2"> to backend math operation."</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_math_operation_value_map</span><span class="p">[</span><span class="n">math_op</span><span class="o">.</span><span class="n">value</span><span class="p">]</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="construct_backend_td"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.construct_backend_td">[docs]</a><span class="k">def</span> <span class="nf">construct_backend_td</span><span class="p">(</span><span class="n">td</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">TileDescription</span><span class="p">,</span>
|
||||
<span class="n">kernel_schedule</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">KernelScheduleType</span><span class="p">)</span> <span class="o">-></span> <span class="n">TileDescription</span><span class="p">:</span>
|
||||
<span class="n">mi</span> <span class="o">=</span> <span class="n">td</span><span class="o">.</span><span class="n">math_instruction</span>
|
||||
<span class="n">backend_mi</span> <span class="o">=</span> <span class="n">MathInstruction</span><span class="p">(</span>
|
||||
<span class="n">mi</span><span class="o">.</span><span class="n">instruction_shape</span><span class="p">,</span>
|
||||
<span class="n">binding_type</span><span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">element_a</span><span class="p">),</span>
|
||||
<span class="n">binding_type</span><span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">element_b</span><span class="p">),</span>
|
||||
<span class="n">binding_type</span><span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">element_accumulator</span><span class="p">),</span>
|
||||
<span class="n">binding_opclass</span><span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">opcode_class</span><span class="p">),</span>
|
||||
<span class="n">backend_math_operation</span><span class="p">(</span><span class="n">mi</span><span class="o">.</span><span class="n">math_operation</span><span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">TileDescription</span><span class="p">(</span><span class="n">td</span><span class="o">.</span><span class="n">threadblock_shape</span><span class="p">,</span> <span class="n">td</span><span class="o">.</span><span class="n">stages</span><span class="p">,</span> <span class="n">td</span><span class="o">.</span><span class="n">warp_count</span><span class="p">,</span>
|
||||
<span class="n">backend_mi</span><span class="p">,</span> <span class="n">td</span><span class="o">.</span><span class="n">cluster_shape</span><span class="p">,</span> <span class="n">kernel_schedule</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="td_from_profiler_op"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.td_from_profiler_op">[docs]</a><span class="k">def</span> <span class="nf">td_from_profiler_op</span><span class="p">(</span><span class="n">op</span><span class="p">)</span> <span class="o">-></span> <span class="n">TileDescription</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Converts the profiler's TileDescription in ``op`` into the backend TileDescription</span>
|
||||
|
||||
<span class="sd"> :param op: profiler Operation</span>
|
||||
|
||||
<span class="sd"> :returns: backend TileDescription</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.TileDescription</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">schedule</span> <span class="o">=</span> <span class="n">op</span><span class="o">.</span><span class="n">kernel_schedule</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="s1">'kernel_schedule'</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
|
||||
<span class="k">return</span> <span class="n">construct_backend_td</span><span class="p">(</span><span class="n">op</span><span class="o">.</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">schedule</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="td_from_profiler_td"><a class="viewcode-back" href="../../../cutlass.utils.html#cutlass.utils.datatypes.td_from_profiler_td">[docs]</a><span class="k">def</span> <span class="nf">td_from_profiler_td</span><span class="p">(</span><span class="n">td</span><span class="p">:</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">TileDescription</span><span class="p">)</span> <span class="o">-></span> <span class="n">TileDescription</span><span class="p">:</span>
|
||||
<span class="w"> </span><span class="sd">"""</span>
|
||||
<span class="sd"> Converts the profiler's TileDescription into the backend TileDescription</span>
|
||||
|
||||
<span class="sd"> :param td: profiler TileDescription</span>
|
||||
<span class="sd"> :type td: cutlass.TileDescription</span>
|
||||
|
||||
<span class="sd"> :returns: backend TileDescription</span>
|
||||
<span class="sd"> :rtype: cutlass.backend.TileDescription</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="n">construct_backend_td</span><span class="p">(</span><span class="n">td</span><span class="p">,</span> <span class="n">kernel_schedule</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span></div>
|
||||
</pre></div>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
||||
<script src="../../../_static/doctools.js"></script>
|
||||
<script src="../../../_static/sphinx_highlight.js"></script>
|
||||
<script src="../../../_static/scripts/furo.js"></script>
|
||||
<script src="../../../_static/clipboard.min.js"></script>
|
||||
<script src="../../../_static/copybutton.js"></script>
|
||||
<script src="../../../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
293
python/docs/_modules/index.html
Normal file
293
python/docs/_modules/index.html
Normal file
@ -0,0 +1,293 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="en">
|
||||
<head><meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="canonical" href="index.html" />
|
||||
|
||||
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
||||
<title>Overview: module code - CUTLASS Python</title>
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
||||
|
||||
|
||||
|
||||
|
||||
<style>
|
||||
body {
|
||||
--color-code-background: #eeffcc;
|
||||
--color-code-foreground: black;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media not print {
|
||||
body[data-theme="dark"] {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body:not([data-theme="light"]) {
|
||||
--color-code-background: #272822;
|
||||
--color-code-foreground: #f8f8f2;
|
||||
--color-brand-primary: #76B900;
|
||||
--color-brand-content: #76B900;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
</style></head>
|
||||
<body>
|
||||
|
||||
<script>
|
||||
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
|
||||
</script>
|
||||
|
||||
|
||||
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
||||
<symbol id="svg-toc" viewBox="0 0 24 24">
|
||||
<title>Contents</title>
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
||||
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-menu" viewBox="0 0 24 24">
|
||||
<title>Menu</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
||||
<line x1="3" y1="12" x2="21" y2="12"></line>
|
||||
<line x1="3" y1="6" x2="21" y2="6"></line>
|
||||
<line x1="3" y1="18" x2="21" y2="18"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
||||
<title>Expand</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
||||
<polyline points="9 18 15 12 9 6"></polyline>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun" viewBox="0 0 24 24">
|
||||
<title>Light mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
||||
<circle cx="12" cy="12" r="5"></circle>
|
||||
<line x1="12" y1="1" x2="12" y2="3"></line>
|
||||
<line x1="12" y1="21" x2="12" y2="23"></line>
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
||||
<line x1="1" y1="12" x2="3" y2="12"></line>
|
||||
<line x1="21" y1="12" x2="23" y2="12"></line>
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-moon" viewBox="0 0 24 24">
|
||||
<title>Dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
||||
</svg>
|
||||
</symbol>
|
||||
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
||||
<title>Auto light/dark mode</title>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
||||
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<circle cx="12" cy="12" r="9" />
|
||||
<path d="M13 12h5" />
|
||||
<path d="M13 15h4" />
|
||||
<path d="M13 18h1" />
|
||||
<path d="M13 9h4" />
|
||||
<path d="M13 6h1" />
|
||||
</svg>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
||||
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
||||
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
||||
<label class="overlay sidebar-overlay" for="__navigation">
|
||||
<div class="visually-hidden">Hide navigation sidebar</div>
|
||||
</label>
|
||||
<label class="overlay toc-overlay" for="__toc">
|
||||
<div class="visually-hidden">Hide table of contents sidebar</div>
|
||||
</label>
|
||||
|
||||
|
||||
|
||||
<div class="page">
|
||||
<header class="mobile-header">
|
||||
<div class="header-left">
|
||||
<label class="nav-overlay-icon" for="__navigation">
|
||||
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<a href="../index.html"><div class="brand">CUTLASS Python</div></a>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="theme-toggle-container theme-toggle-header">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
</header>
|
||||
<aside class="sidebar-drawer">
|
||||
<div class="sidebar-container">
|
||||
|
||||
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">
|
||||
|
||||
<div class="sidebar-logo-container">
|
||||
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt=""/>
|
||||
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt=""/>
|
||||
</div>
|
||||
|
||||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||||
|
||||
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
|
||||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||||
<input type="hidden" name="check_keywords" value="yes">
|
||||
<input type="hidden" name="area" value="default">
|
||||
</form>
|
||||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../externals/00_basic_gemm.html">Getting Started</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../externals/01_epilogue.html">Epilogue</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</aside>
|
||||
<div class="main">
|
||||
<div class="content">
|
||||
<div class="article-container">
|
||||
<a href="#" class="back-to-top muted-link">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||||
</svg>
|
||||
<span>Back to top</span>
|
||||
</a>
|
||||
<div class="content-icon-container">
|
||||
<div class="theme-toggle-container theme-toggle-content">
|
||||
<button class="theme-toggle">
|
||||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||||
</button>
|
||||
</div>
|
||||
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
||||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||||
</label>
|
||||
</div>
|
||||
<article role="main">
|
||||
<h1>All modules for which code is available</h1>
|
||||
<ul><li><a href="cutlass/emit/pytorch.html">cutlass.emit.pytorch</a></li>
|
||||
<li><a href="cutlass/epilogue.html">cutlass.epilogue</a></li>
|
||||
<li><a href="cutlass/library_defaults.html">cutlass.library_defaults</a></li>
|
||||
<li><a href="cutlass/op/gemm.html">cutlass.op.gemm</a></li>
|
||||
<li><a href="cutlass/op/gemm_grouped.html">cutlass.op.gemm_grouped</a></li>
|
||||
<li><a href="cutlass/op/op.html">cutlass.op.op</a></li>
|
||||
<li><a href="cutlass/swizzle.html">cutlass.swizzle</a></li>
|
||||
<li><a href="cutlass/utils/check.html">cutlass.utils.check</a></li>
|
||||
<li><a href="cutlass/utils/datatypes.html">cutlass.utils.datatypes</a></li>
|
||||
</ul>
|
||||
</article>
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<div class="related-pages">
|
||||
|
||||
|
||||
</div>
|
||||
<div class="bottom-of-page">
|
||||
<div class="left-details">
|
||||
<div class="copyright">
|
||||
Copyright © 2023, NVIDIA
|
||||
</div>
|
||||
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
||||
|
||||
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
||||
|
||||
</div>
|
||||
<div class="right-details">
|
||||
<div class="icons">
|
||||
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
<aside class="toc-drawer no-toc">
|
||||
|
||||
|
||||
|
||||
</aside>
|
||||
</div>
|
||||
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script src="../_static/sphinx_highlight.js"></script>
|
||||
<script src="../_static/scripts/furo.js"></script>
|
||||
<script src="../_static/clipboard.min.js"></script>
|
||||
<script src="../_static/copybutton.js"></script>
|
||||
<script src="../_static/tabs.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user