CUTLASS 2.10 (#615)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2022-09-03 15:48:46 -07:00
committed by GitHub
parent ca23ff7924
commit b72cbf957d
289 changed files with 43708 additions and 2513 deletions

View File

@ -0,0 +1,75 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief In-memory compiled artifact cache
*/
#include <pybind11/pybind11.h>

#include <cstdint>
#include <string>
#include <unordered_map>
namespace py = pybind11;
namespace cutlass {
/// In-memory cache mapping a kernel's source/configuration key to its
/// compiled artifact (held as an opaque python object).
struct CompileCache {
public:
  CompileCache() = default;
  ~CompileCache() = default;

  /// Key: kernel source/configuration string. Value: compiled artifact.
  using Cache = std::unordered_map<std::string, py::object>;

  /// Return the compiled artifact for `kernel`, or py::none() when the
  /// kernel has not been compiled yet.
  py::object at(const std::string &kernel) const {
    auto item = cache_.find(kernel);
    if (item != cache_.end()) {
      return item->second;
    }
    return py::none();
  }

  /// Insert a newly compiled kernel for a new configuration. A pre-existing
  /// entry for the same key is kept (emplace semantics: first insert wins).
  void insert(const std::string &kernel, const py::object &compiled_kernel) {
    cache_.emplace(kernel, compiled_kernel);
  }

  /// Number of cached kernels.
  /// (fix: dropped the meaningless `const` on the by-value return type and
  /// const-qualified the accessor)
  int64_t size() const { return static_cast<int64_t>(cache_.size()); }

  /// Drop every cached artifact.
  void clear() { cache_.clear(); }

private:
  Cache cache_;
};
} // namespace cutlass

View File

@ -0,0 +1,181 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief binding cutlass C++ APIs to python
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "builtin_types.h"
#include "device_launch_parameters.h"
#include "stddef.h"
#include "cutlass/cutlass.h"
#include "include/conv/convolution.h"
#include "include/gemm/gemm.h"
#include "include/types.h"
#include "include/layout/layout.h"
#include "include/tensor_coord.h"
#include "include/arch.h"
#include "include/tensor_ref_view.h"
#include "include/swizzling.h"
#include "test/conv/convolution.h"
#include "test/gemm/gemm.h"
// Data Types
#include "library.h"
// compiler
#include "compiler.h"
namespace py = pybind11;
/// Python module definition: binds CUTLASS C++ types, layouts, GEMM/conv
/// operations, test helpers, and the compile cache into the `cutlass` module.
PYBIND11_MODULE(cutlass, m) {
  // module doc
  m.doc() = "cutlass C++ binding";

  // Bind data type wrappers (library.h)
  bind_cutlass_types(m);

  // Bind layout classes
  bind_layout(m);

  // Bind tensor coordinates
  bind_tensor_coord(m);

  // Bind tensor refs / views
  bind_tensor_refs_and_views(m);

  // Bind opcode classification enum
  bind_opcode(m);

  // Convolution bindings live in the `conv` submodule
  py::module_ conv_submodule = m.def_submodule("conv");
  bind_convolution(conv_submodule);

  // GEMM bindings live in the `gemm` submodule
  py::module_ gemm_submodule = m.def_submodule("gemm");
  bind_gemm(gemm_submodule);

  // Threadblock swizzling functors
  bind_threadblock_swizzle(m);

  // Test units under `test.conv` / `test.gemm`
  py::module_ test = m.def_submodule("test");
  py::module_ test_conv = test.def_submodule("conv");
  bind_convolution_test(test_conv);
  py::module_ test_gemm = test.def_submodule("gemm");
  bind_gemm_test(test_gemm);

  // Data types (library.h DataType).
  // NOTE(review): floating-point types (f16/bf16/f32/...) are not bound here
  // — presumably handled on the python side; confirm against callers.
  py::enum_<cutlass::DataType>(m, "dtype")
      .value("b1", cutlass::DataType::kB1)
      .value("u2", cutlass::DataType::kU2)
      .value("u4", cutlass::DataType::kU4)
      .value("u8", cutlass::DataType::kU8)
      .value("u16", cutlass::DataType::kU16)
      .value("u32", cutlass::DataType::kU32)
      .value("u64", cutlass::DataType::kU64)
      .value("s2", cutlass::DataType::kS2)
      .value("s4", cutlass::DataType::kS4)
      // fix: s8/s32 were missing although every other sized integer is bound.
      // NOTE(review): assumes kS8/kS32 exist in library.h's DataType — confirm.
      .value("s8", cutlass::DataType::kS8)
      .value("s16", cutlass::DataType::kS16)
      .value("s32", cutlass::DataType::kS32)
      .value("s64", cutlass::DataType::kS64)
      .value("cf16", cutlass::DataType::kCF16)
      .value("cbf16", cutlass::DataType::kCBF16)
      .value("cf32", cutlass::DataType::kCF32)
      .value("ctf32", cutlass::DataType::kCTF32)
      .value("cf64", cutlass::DataType::kCF64)
      .value("cs2", cutlass::DataType::kCS2)
      .value("cs4", cutlass::DataType::kCS4)
      .value("cs8", cutlass::DataType::kCS8)
      .value("cs16", cutlass::DataType::kCS16)
      .value("cs32", cutlass::DataType::kCS32)
      .value("cs64", cutlass::DataType::kCS64)
      .value("cu2", cutlass::DataType::kCU2)
      .value("cu4", cutlass::DataType::kCU4)
      .value("cu8", cutlass::DataType::kCU8)
      .value("cu16", cutlass::DataType::kCU16)
      .value("cu32", cutlass::DataType::kCU32)
      .value("cu64", cutlass::DataType::kCU64)
      .value("invalid", cutlass::DataType::kInvalid);

  // Layout types
  py::enum_<cutlass::LayoutType>(m, "layout")
      .value("ColumnMajorInterleaved2", cutlass::LayoutType::kColumnMajorInterleaved2)
      .value("RowMajorInterleaved2", cutlass::LayoutType::kRowMajorInterleaved2)
      .value("ColumnMajorInterleaved64", cutlass::LayoutType::kColumnMajorInterleaved64)
      .value("RowMajorInterleaved64", cutlass::LayoutType::kRowMajorInterleaved64)
      .value("TensorNDHWC", cutlass::LayoutType::kTensorNDHWC)
      .value("TensorNCHW", cutlass::LayoutType::kTensorNCHW)
      .value("TensorNGHWC", cutlass::LayoutType::kTensorNGHWC)
      .value("TensorNC64HW64", cutlass::LayoutType::kTensorNC64HW64)
      .value("TensorC64RSK64", cutlass::LayoutType::kTensorC64RSK64);

  // Complex transform types
  py::enum_<cutlass::ComplexTransform>(m, "complex_transform")
      .value("none", cutlass::ComplexTransform::kNone)
      .value("conj", cutlass::ComplexTransform::kConjugate);

  // Compile cache (compiler.h)
  py::class_<cutlass::CompileCache>(m, "CompileCache")
      .def(py::init<>())
      .def("at", &cutlass::CompileCache::at)
      .def("insert", &cutlass::CompileCache::insert)
      .def("size", &cutlass::CompileCache::size)
      .def("clear", &cutlass::CompileCache::clear);
}

View File

@ -0,0 +1,59 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind opcode classes to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/arch/mma.h"
namespace py = pybind11;
namespace cutlass {
/// Classification of the math instruction a kernel is built around.
enum class OpcodeClass {
  kSimt,            // thread-level (SIMT) math
  kTensorOp,        // Tensor Core MMA
  kWmmaTensorOp,    // WMMA-based Tensor Core
  kSparseTensorOp   // sparse Tensor Core
};
} // namespace cutlass
/// Bind the OpcodeClass enumeration to python as `OpClass`.
/// (fix: corrected "classifing" -> "classifying" and "sparseTensor Core" ->
/// "sparse Tensor Core" in the user-facing docstrings)
void bind_opcode(py::module &m) {
  py::enum_<cutlass::OpcodeClass>(m, "OpClass",
      R"pbdoc(classification of math operators)pbdoc")
      .value("Simt", cutlass::OpcodeClass::kSimt,
             R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc")
      .value("TensorOp", cutlass::OpcodeClass::kTensorOp,
             R"pbdoc(Tag classifying operators as Tensor Core operations)pbdoc")
      .value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp,
             R"pbdoc(Tag classifying operators as WMMA Tensor Core operations)pbdoc")
      .value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp,
             R"pbdoc(Tag classifying operators as sparse Tensor Core operations)pbdoc");
}

View File

@ -0,0 +1,102 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Convolution problem sizes to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/conv/conv2d_problem_size.h"
namespace py = pybind11;
/// Bind cutlass::conv::Conv2dProblemSize and the implicit-GEMM size/extent
/// helper functions to the given python module.
void bind_conv_problem_size(py::module &m) {
  namespace conv = cutlass::conv;
  using Problem = conv::Conv2dProblemSize;

  // Conv2d problem size (include/cutlass/conv/conv2d_problem_size.h)
  py::class_<Problem> problem(m, "Conv2dProblemSize");

  // Constructors: one taking the unrolled scalar parameters, one taking
  // coordinate / stride / dilation aggregates.
  problem
      .def(py::init<int, int, int, int, int, int, int, int, int, int, int,
                    int, int, int, int, conv::Mode, int, int>())
      .def(py::init<cutlass::Tensor4DCoord, cutlass::Tensor4DCoord,
                    cutlass::Tensor4DCoord, cutlass::MatrixCoord,
                    cutlass::MatrixCoord, conv::Mode, int, int>());

  // Attribute accessors
  problem.def_readwrite("N", &Problem::N)
      .def_readwrite("H", &Problem::H)
      .def_readwrite("W", &Problem::W)
      .def_readwrite("C", &Problem::C)
      .def_readwrite("P", &Problem::P)
      .def_readwrite("Q", &Problem::Q)
      .def_readwrite("K", &Problem::K)
      .def_readwrite("R", &Problem::R)
      .def_readwrite("S", &Problem::S)
      .def_readwrite("pad_h", &Problem::pad_h)
      .def_readwrite("pad_w", &Problem::pad_w)
      .def_readwrite("stride_h", &Problem::stride_h)
      .def_readwrite("stride_w", &Problem::stride_w)
      .def_readwrite("dilation_h", &Problem::dilation_h)
      .def_readwrite("dilation_w", &Problem::dilation_w)
      .def_readwrite("mode", &Problem::mode)
      .def_readwrite("split_k_slices", &Problem::split_k_slices)
      .def_readwrite("groups", &Problem::groups);

  // Member functions
  problem.def("reset_split_k_slices", &Problem::reset_split_k_slices)
      .def("activation_extent", &Problem::activation_extent)
      .def("filter_extent", &Problem::filter_extent)
      .def("output_extent", &Problem::output_extent)
      .def("activation_size", &Problem::activation_size)
      .def("filter_size", &Problem::filter_size)
      .def("output_size", &Problem::output_size);

  // Implicit-GEMM tensor sizes
  m.def("implicit_gemm_tensor_a_size",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_tensor_a_size));
  m.def("implicit_gemm_tensor_b_size",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_tensor_b_size));
  m.def("implicit_gemm_tensor_c_size",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_tensor_c_size));

  // Implicit-GEMM tensor extents
  m.def("implicit_gemm_tensor_a_extent",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_tensor_a_extent));
  m.def("implicit_gemm_tensor_b_extent",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_tensor_b_extent));
  m.def("implicit_gemm_tensor_c_extent",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_tensor_c_extent));

  // Implicit-GEMM equivalent problem size
  m.def("implicit_gemm_problem_size",
        py::overload_cast<conv::Operator, const Problem &>(
            &conv::implicit_gemm_problem_size));
}

View File

@ -0,0 +1,91 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind convolution related enum types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "conv_problem_size.h"
#include "host.h"
#include "cutlass/conv/convolution.h"
namespace py = pybind11;
/// Bind convolution enumerations (cutlass/conv/convolution.h), problem
/// sizes, and host helpers to the given python module.
void bind_convolution(py::module &m) {
  namespace conv = cutlass::conv;

  // Convolutional operator
  py::enum_<conv::Operator>(m, "Operator", R"pbdoc(Convolutional operator)pbdoc")
      .value("fprop", conv::Operator::kFprop, "Forward propagation")
      .value("dgrad", conv::Operator::kDgrad, "Activation grad")
      .value("wgrad", conv::Operator::kWgrad, "Weight grad");

  // Convolution vs. cross correlation
  py::enum_<conv::Mode>(m, "Mode")
      .value("cross_correlation", conv::Mode::kCrossCorrelation)
      .value("convolution", conv::Mode::kConvolution);

  // Iterator algorithm: performance/simplicity trade-off
  py::enum_<conv::IteratorAlgorithm>(m, "IteratorAlgorithm",
      R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc")
      .value("analytic", conv::IteratorAlgorithm::kAnalytic,
             R"pbdoc(functionally correct in all cases but lower performance)pbdoc")
      .value("optimized", conv::IteratorAlgorithm::kOptimized,
             R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc")
      .value("fixed_channels", conv::IteratorAlgorithm::kFixedChannels,
             R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc")
      .value("few_channels", conv::IteratorAlgorithm::kFewChannels,
             R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc");

  // Stride support: unit-stride specializations vs. arbitrary stride
  py::enum_<conv::StrideSupport>(m, "StrideSupport",
      R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution
stride is unit.)pbdoc")
      .value("strided", conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc")
      .value("unity", conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc");

  // Split-K mode
  // NOTE(review): the value name "None" is a python keyword, so
  // `SplitKMode.None` is a syntax error in python source; callers must use
  // getattr(SplitKMode, "None"). Renaming would break the existing API, so
  // this is flagged rather than changed — confirm against the python layer.
  py::enum_<conv::SplitKMode>(m, "SplitKMode")
      .value("None", conv::SplitKMode::kNone)
      .value("Serial", conv::SplitKMode::kSerial)
      .value("Parallel", conv::SplitKMode::kParallel);

  // Conv problem sizes
  bind_conv_problem_size(m);

  // Host helper functions under the `host` submodule
  py::module_ host_submodule = m.def_submodule("host");
  bind_conv_host_helper(host_submodule);
}

View File

@ -0,0 +1,54 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind conv host helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
/// Bind convolution host-side helper utilities to the given python module.
void bind_conv_host_helper(py::module &m) {
  using InterleavedRef =
      cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>>;

  // Reorder operand B for the interleaved layout: map the conv problem to
  // its implicit-GEMM shape, then reorder with a 32-element interleave.
  auto reorder = [](InterleavedRef dest, InterleavedRef src,
                    cutlass::conv::Operator conv_op,
                    const cutlass::conv::Conv2dProblemSize &problem_size) {
    cutlass::gemm::GemmCoord implicit_problem_size =
        cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size);
    cutlass::reorder_convK<32>(dest, src, implicit_problem_size);
  };
  m.def("reorder_convK", reorder);
}

View File

@ -0,0 +1,77 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm related enum types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/gemm.h"
#include "host.h"
namespace py = pybind11;
/// Bind GEMM enumerations (cutlass/gemm/gemm.h), GemmCoord, and host
/// helpers to the given python module.
void bind_gemm(py::module &m) {
  namespace gemm = cutlass::gemm;

  // GEMM universal mode
  py::enum_<gemm::GemmUniversalMode>(m, "Mode")
      .value("Gemm", gemm::GemmUniversalMode::kGemm,
             "Ordinary GEMM & GEMM Split-K serial")
      .value("GemmSplitKParallel", gemm::GemmUniversalMode::kGemmSplitKParallel,
             "GEMM Split-K parallel")
      .value("Batched", gemm::GemmUniversalMode::kBatched, "Batched GEMM")
      .value("Array", gemm::GemmUniversalMode::kArray)
      .value("Invalid", gemm::GemmUniversalMode::kInvalid);

  // GemmCoord: a location within the coordinate space of a GEMM problem
  py::class_<gemm::GemmCoord>(m, "GemmCoord")
      .def(py::init<int, int, int>())
      // scalar accessors (non-const overloads selected via overload_cast)
      .def("m", py::overload_cast<>(&gemm::GemmCoord::m))
      .def("n", py::overload_cast<>(&gemm::GemmCoord::n))
      .def("k", py::overload_cast<>(&gemm::GemmCoord::k))
      // 2-D projections returned as MatrixCoord
      .def("mk",
           [](const gemm::GemmCoord &coord) {
             return cutlass::MatrixCoord(coord.mk());
           })
      .def("kn",
           [](const gemm::GemmCoord &coord) {
             return cutlass::MatrixCoord(coord.kn());
           })
      .def("mn",
           [](const gemm::GemmCoord &coord) {
             return cutlass::MatrixCoord(coord.mn());
           });

  // Host helper functions under the `host` submodule
  py::module_ host_submodule = m.def_submodule("host");
  bind_gemm_host_helper(host_submodule);
}

View File

@ -0,0 +1,47 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm host helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
/// Bind GEMM host-side helper utilities to the given python module.
/// Registers two python overloads of `reorder_column`, one per interleaved
/// layout of the int8 source tensor.
void bind_gemm_host_helper(py::module &m) {
  using cutlass::layout::ColumnMajorInterleaved;
  using cutlass::layout::RowMajorInterleaved;
  m.def("reorder_column",
        &cutlass::reorder_column<32, int8_t, RowMajorInterleaved<32>>);
  m.def("reorder_column",
        &cutlass::reorder_column<32, int8_t, ColumnMajorInterleaved<32>>);
}

View File

@ -0,0 +1,47 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind CUTLASS layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "tensor.h"
#include "matrix.h"
namespace py = pybind11;
/// Register every CUTLASS layout binding on the given python module.
void bind_layout(py::module &m) {
  // Tensor layouts (tensor.h)
  bind_tensor_layout(m);
  // Matrix layouts (matrix.h)
  bind_matrix_layout(m);
}

View File

@ -0,0 +1,87 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Matrix layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/layout/matrix.h"
namespace py = pybind11;
/// Registers the CUTLASS matrix layout classes (cutlass/layout/matrix.h)
/// with python module `m`. Each layout exposes a static `packed` factory and
/// a `stride` accessor returning the leading-dimension stride.
void bind_matrix_layout(py::module &m) {
    //
    // Matrix layouts
    // cutlass/layout/matrix.h
    //

    // Row-major
    auto row_major = py::class_<cutlass::layout::RowMajor>(m, "RowMajor", R"pbdoc(
Mapping function for row-major matrices.
)pbdoc");
    row_major.def_static("packed", &cutlass::layout::RowMajor::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc");
    row_major.def("stride", [](const cutlass::layout::RowMajor &self) {
        return self.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");

    // Column-major
    auto column_major = py::class_<cutlass::layout::ColumnMajor>(m, "ColumnMajor", R"pbdoc(
Mapping function for column-major matrices.
)pbdoc");
    column_major.def_static("packed", &cutlass::layout::ColumnMajor::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc");
    column_major.def("stride", [](const cutlass::layout::ColumnMajor &self) {
        return self.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");

    // Row-major interleaved, 32-element columns
    auto row_major_interleaved = py::class_<cutlass::layout::RowMajorInterleaved<32>>(m, "RowMajorInterleaved32",
        R"pbdoc(Mapping function for interleaved matrices. Matrix is structured
as row-major arrangement of fixed-size columns 32)pbdoc");
    row_major_interleaved.def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc");
    row_major_interleaved.def("stride", [](const cutlass::layout::RowMajorInterleaved<32> &self) {
        return self.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");

    // Column-major interleaved, 32-element rows
    auto column_major_interleaved = py::class_<cutlass::layout::ColumnMajorInterleaved<32>>(m, "ColumnMajorInterleaved32",
        R"pbdoc(Mapping function for interleaved matrices. Matrix is structured
as column-major arrangement of fixed-size rows 32)pbdoc");
    column_major_interleaved.def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc");
    column_major_interleaved.def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> &self) {
        return self.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");
}

View File

@ -0,0 +1,74 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Tensor layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
/// Registers the CUTLASS 4-D tensor layout classes
/// (cutlass/include/cutlass/layout/tensor.h) with python module `m`.
void bind_tensor_layout(py::module &m) {
    //
    // Tensor layouts
    // cutlass/include/cutlass/layout/tensor.h
    //

    /// Mapping function for 4-D NHWC tensors.
    auto nhwc = py::class_<cutlass::layout::TensorNHWC>(m, "TensorNHWC",
        R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc");
    nhwc.def_static("packed", &cutlass::layout::TensorNHWC::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed NHWC tensor)pbdoc");
    nhwc.def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride),
        R"pbdoc(Returns the stride of the layout)pbdoc");

    /// Mapping function for 4-D NC/xHWx tensors.
    auto nc32hw32 = py::class_<cutlass::layout::TensorNCxHWx<32>>(m, "TensorNC32HW32",
        R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc");
    nc32hw32.def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc");
    nc32hw32.def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride),
        R"pbdoc(Returns the stride of the layout)pbdoc");

    /// Mapping function for 4-D CxRSKx tensors.
    auto c32rsk32 = py::class_<cutlass::layout::TensorCxRSKx<32>>(m, "TensorC32RSK32",
        R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc");
    c32rsk32.def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed,
        py::arg("extent"),
        R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc");
    c32rsk32.def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride),
        R"pbdoc(Returns the stride of the layout)pbdoc");
}

View File

@ -0,0 +1,152 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind threadblock swizzling to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include <boost/core/demangle.hpp>
#include <cuda_runtime.h>
namespace py = pybind11;
template<typename T>
void bind_identity_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv3dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
// TODO: the returned dim3 is not usable in python
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return boost::core::demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
}
template<typename T>
void bind_swizzle(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return boost::core::demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
}
template<typename T>
void bind_dgrad_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
.def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) {
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
}, py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return boost::core::demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
}
/// Registers the dim3 helper type and every supported threadblock swizzle
/// with python module `m`.
void bind_threadblock_swizzle(py::module &m) {
    // CUDA dim3, exposed so grid shapes can be constructed/inspected in python.
    auto dim3_class = py::class_<dim3>(m, "dim3",
        R"pbdoc(A int3 type xyz contains three integers)pbdoc");
    dim3_class.def(py::init<int, int, int>(),
        py::arg("x"), py::arg("y"), py::arg("z"));
    dim3_class.def_readwrite("x", &dim3::x, R"pbdoc(get value x)pbdoc");
    dim3_class.def_readwrite("y", &dim3::y, R"pbdoc(get value y)pbdoc");
    dim3_class.def_readwrite("z", &dim3::z, R"pbdoc(get value z)pbdoc");

    // Identity swizzles at the supported interleaving factors.
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>>(m, "IdentitySwizzle1");
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>>(m, "IdentitySwizzle2");
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>>(m, "IdentitySwizzle4");
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>>(m, "IdentitySwizzle8");

    // Horizontal and batched-identity swizzles.
    bind_swizzle<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle>(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc");
    bind_swizzle<cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc");

    // Strided-dgrad convolution swizzles.
    bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>>(m, "StridedDgradIdentitySwizzle1");
    bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>>(m, "StridedDgradIdentitySwizzle4");
    bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle>(m, "StridedDgradHorizontalSwizzle");
}

View File

@ -0,0 +1,72 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Tensor Coord to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/tensor_coord.h"
namespace py = pybind11;
/// Registers the CUTLASS coordinate types (cutlass/tensor_coord.h) with
/// python module `m`.
void bind_tensor_coord(py::module &m) {
    //
    // Tensor Coords
    // cutlass/include/cutlass/tensor_coord.h
    //

    /// Defines a canonical 4D coordinate used by tensor operations.
    auto coord_4d = py::class_<cutlass::Tensor4DCoord>(m, "Tensor4DCoord",
        R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc");
    coord_4d.def(py::init<int, int, int, int>(),
        py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
        R"pbdoc(Helper to construct from N, H, W, and C)pbdoc");

    // Generic rank-3 coordinate, exposed under the name Tensor3DCoord.
    auto coord_3d = py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
        R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc");
    coord_3d.def("at", py::overload_cast<int>(&cutlass::Coord<3>::at),
        py::arg("dim"),
        R"pbdoc(Gets the index of a given Coord element)pbdoc");

    // Matrix Size
    auto matrix_coord = py::class_<cutlass::MatrixCoord>(m, "MatrixCoord",
        R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc");
    matrix_coord.def(py::init<int, int>(),
        py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc");
    matrix_coord.def("row", py::overload_cast<>(&cutlass::MatrixCoord::row),
        R"pbdoc(Returns the row of the coordinate)pbdoc");
    matrix_coord.def("column", py::overload_cast<>(&cutlass::MatrixCoord::column),
        R"pbdoc(Returns the column of the coordinate)pbdoc");
}

View File

@ -0,0 +1,102 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind TensorRef and View to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "types.h"
template<typename T, typename L, typename TF>
void bind_tensor_ref_view(py::module &m, std::string name) {
py::class_<cutlass::TensorRef<T, L>>(m, ("TensorRef" + name).c_str())
.def("__init__", [](cutlass::TensorRef<T, L>& tensor_ref, int64_t address, const L& layout_ ) {
T* ptr = reinterpret_cast< T*>(address);
new (&tensor_ref) cutlass::TensorRef<T, L>(ptr, layout_);
})
.def("data", [](cutlass::TensorRef<T, L>& tensor_ref) {
T* ptr = tensor_ref.data();
return int64_t(ptr);
})
.def("layout", py::overload_cast<>(&cutlass::TensorRef<T, L>::layout));
m.def("get_tensor_ref", [](int64_t address, TF data, const L& layout_) {
T* ptr = reinterpret_cast<T*>(address);
cutlass::TensorRef<T, L> tensor_ref = cutlass::TensorRef<T, L>(ptr, layout_);
return tensor_ref;
});
py::class_<cutlass::TensorView<T, L>>(m, ("TensorView" + name).c_str())
.def(py::init<const cutlass::TensorRef<T, L>&, const typename L::TensorCoord &>());
}
/// Instantiates TensorRef/TensorView bindings for every supported
/// (element type, layout) pair. The string suffix names the python classes
/// (e.g. "TensorRefF32RowMajor"); the third template argument is the
/// python-side element wrapper used by get_tensor_ref overload dispatch.
void bind_tensor_refs_and_views(py::module &m) {
/// float
bind_tensor_ref_view<float, cutlass::layout::RowMajor, cutlass::float32>(m, "F32RowMajor");
bind_tensor_ref_view<float, cutlass::layout::ColumnMajor, cutlass::float32>(m, "F32ColumnMajor");
bind_tensor_ref_view<float, cutlass::layout::TensorNHWC, cutlass::float32>(m, "F32NHWC");
/// double
bind_tensor_ref_view<double, cutlass::layout::RowMajor, cutlass::float64>(m, "F64RowMajor");
bind_tensor_ref_view<double, cutlass::layout::ColumnMajor, cutlass::float64>(m, "F64ColumnMajor");
bind_tensor_ref_view<double, cutlass::layout::TensorNHWC, cutlass::float64>(m, "F64NHWC");
// half_t
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t>(m, "F16RowMajor");
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t>(m, "F16ColumnMajor");
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::half_t>(m, "F16NHWC");
// bfloat16
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t>(m, "BF16RowMajor");
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t>(m, "BF16ColumnMajor");
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::TensorNHWC, cutlass::bfloat16_t>(m, "BF16NHWC");
// int8_t (includes the interleaved layouts used by integer tensor-op GEMMs)
bind_tensor_ref_view<int8_t, cutlass::layout::RowMajorInterleaved<32>, cutlass::int8>(m, "S8RowMajorInterleaved32");
bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajorInterleaved<32>, cutlass::int8>(m, "S8ColumnMajorInterleaved32");
bind_tensor_ref_view<int8_t, cutlass::layout::RowMajor, cutlass::int8>(m, "S8RowMajor");
bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajor, cutlass::int8>(m, "S8ColumnMajor");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorNHWC, cutlass::int8>(m, "S8NHWC");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorNCxHWx<32>, cutlass::int8>(m, "S8NC32HW32");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::int8>(m, "S8C32RSK32");
// int32_t
bind_tensor_ref_view<int32_t, cutlass::layout::RowMajor, cutlass::int32>(m, "S32RowMajor");
bind_tensor_ref_view<int32_t, cutlass::layout::ColumnMajor, cutlass::int32>(m, "S32ColumnMajor");
bind_tensor_ref_view<int32_t, cutlass::layout::TensorNHWC, cutlass::int32>(m, "S32NHWC");
}

View File

@ -0,0 +1,146 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind CUTLASS types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/half.h"
namespace py = pybind11;
namespace cutlass {

/// 8-bit signed integer wrapper exposed to python.
/// (Comment fixed: previously mislabeled "IEEE 32-bit signed integer".)
struct alignas(1) int8 {
    int8_t storage;  // raw 8-bit value

    explicit int8(int x) {
        storage = int8_t(x);
    }
    explicit int8(float x) {
        storage = int8_t(x);  // truncates toward zero
    }
    /// Returns the underlying C value.
    int8_t c_value() const { return storage; }
};

/// 32-bit signed integer wrapper exposed to python.
struct alignas(4) int32 {
    int storage;  // raw 32-bit value

    explicit int32(int x) {
        storage = x;
    }
    explicit int32(float x) {
        storage = int(x);  // truncates toward zero
    }
    /// Returns the underlying C value.
    int c_value() const { return storage; }
};

/// IEEE single-precision (binary32) floating-point wrapper.
struct alignas(4) float32 {
    float storage;  // raw float value

    explicit float32(float x) {
        storage = x;
    }
    explicit float32(int x) {
        storage = float(x);
    }
    /// Returns the underlying C value.
    float c_value() const { return storage; }
};

/// IEEE double-precision (binary64) floating-point wrapper.
/// Alignment fixed from alignas(4): an alignment weaker than
/// alignof(double) (8 on common ABIs) is ill-formed.
struct alignas(8) float64 {
    double storage;  // raw double value

    explicit float64(float x) {
        storage = double(x);
    }
    explicit float64(int x) {
        storage = double(x);
    }
    /// Returns the underlying C value.
    double c_value() const { return storage; }
};

} // namespace cutlass
/// Registers the numeric wrapper types (int8/int32/float16/bfloat16/float32/
/// tfloat32/float64) with python module `m`. Each exposes float and int
/// constructors, raw `storage`, and a `value` accessor.
void bind_cutlass_types(py::module &m) {
    // s8
    auto s8 = py::class_<cutlass::int8>(m, "int8");
    s8.def(py::init<float>());
    s8.def(py::init<int>());
    s8.def_readwrite("storage", &cutlass::int8::storage);
    s8.def("value", &cutlass::int8::c_value);

    // s32
    auto s32 = py::class_<cutlass::int32>(m, "int32");
    s32.def(py::init<float>());
    s32.def(py::init<int>());
    s32.def_readwrite("storage", &cutlass::int32::storage);
    s32.def("value", &cutlass::int32::c_value);

    // f16 (cutlass::half_t is value-like in python, so value() returns a copy)
    auto f16 = py::class_<cutlass::half_t>(m, "float16");
    f16.def(py::init<float>());
    f16.def(py::init<double>());
    f16.def(py::init<int>());
    f16.def(py::init<unsigned>());
    f16.def_readwrite("storage", &cutlass::half_t::storage);
    f16.def("value", [](const cutlass::half_t& value) {return value;});

    // bf16
    auto bf16 = py::class_<cutlass::bfloat16_t>(m, "bfloat16");
    bf16.def(py::init<float>());
    bf16.def(py::init<int>());
    bf16.def_readwrite("storage", &cutlass::bfloat16_t::storage);
    bf16.def("value", [](const cutlass::bfloat16_t& value) {return value;});

    // f32
    auto f32 = py::class_<cutlass::float32>(m, "float32");
    f32.def(py::init<float>());
    f32.def(py::init<int>());
    f32.def_readwrite("storage", &cutlass::float32::storage);
    f32.def("value", &cutlass::float32::c_value);

    // tf32
    auto tf32 = py::class_<cutlass::tfloat32_t>(m, "tfloat32");
    tf32.def(py::init<float>());
    tf32.def(py::init<int>());
    tf32.def_readwrite("storage", &cutlass::tfloat32_t::storage);
    tf32.def("value", [](const cutlass::tfloat32_t& value) {return value;});

    // f64
    auto f64 = py::class_<cutlass::float64>(m, "float64");
    f64.def(py::init<float>());
    f64.def(py::init<int>());
    f64.def_readwrite("storage", &cutlass::float64::storage);
    f64.def("value", &cutlass::float64::c_value);
}

View File

@ -0,0 +1,32 @@
#pragma once  // added for consistency with the sibling binding headers
#include <cutlass/complex.h>
namespace cutlass {

/// Enumerates the element data types understood by the python bindings
/// (b = bit, u = unsigned, s = signed, f = float, c = complex prefix).
/// Enumerator order is load-bearing: do not reorder.
enum class DataType {
  kB1, kU2, kU4, kU8,
  kU16, kU32, kU64, kS2,
  kS4, kS8, kS16, kS32,
  kS64, kF16, kBF16, kF32,
  kTF32, kF64, kCF16, kCBF16,
  kCF32, kCTF32, kCF64, kCS2,
  kCS4, kCS8, kCS16, kCS32,
  kCS64, kCU2, kCU4, kCU8,
  kCU16, kCU32, kCU64, kInvalid
};

/// Enumerates the matrix/tensor layout types understood by the python
/// bindings. Enumerator order is load-bearing: do not reorder.
enum class LayoutType {
  kColumnMajor, kRowMajor,
  kColumnMajorInterleaved2, kRowMajorInterleaved2,
  kColumnMajorInterleaved32, kRowMajorInterleaved32,
  kColumnMajorInterleaved64, kRowMajorInterleaved64,
  kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC,
  kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32,
  kTensorC64RSK64
};

/// ENUM class for opcode class
// NOTE(review): no OpcodeClass enum follows the comment above -- it looks
// like a leftover placeholder; confirm before removing.
} // namespace cutlass

View File

@ -0,0 +1,54 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind convolution problems to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "unit/conv/device/conv2d_problems.h"
#include "cutlass/conv/conv2d_problem_size.h"
namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(std::vector<cutlass::conv::Conv2dProblemSize>);
/// Registers the opaque Conv2d problem-size vector and the unit-test problem
/// generator with python module `m`.
void bind_conv_problem_size_test(py::module &m) {
    // Opaque std::vector binding (see PYBIND11_MAKE_OPAQUE above).
    auto problem_vector = py::bind_vector<std::vector<cutlass::conv::Conv2dProblemSize>>(m, "Conv2dProblemVector");
    problem_vector.def("size", &std::vector<cutlass::conv::Conv2dProblemSize>::size);

    // Get Conv2d problem sizes
    auto testbed_sizes = py::class_<test::conv::device::TestbedConv2dProblemSizes>(m, "TestbedConv2dProblemSizes");
    testbed_sizes.def(py::init<int>());
    testbed_sizes.def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes);
}

View File

@ -0,0 +1,49 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind convolution related types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "conv_problems.h"
#include "host.h"
namespace py = pybind11;
/// Register all convolution test bindings: problem sizes on `m` itself and
/// host reference helpers under the `host` submodule.
void bind_convolution_test(py::module &m) {
  // Conv2d problem-size containers and testbed.
  bind_conv_problem_size_test(m);

  // Host-side reference implementations live under `<m>.host`.
  py::module_ host = m.def_submodule("host");
  bind_conv_host_references(host);
}

View File

@ -0,0 +1,180 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Convolution host test helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "unit/conv/device/cache_testbed_output.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/host/tensor_compare.h"
namespace py = pybind11;
/// Bind the host Conv2d reference ("conv2d") and its cached-test key factory
/// ("CreateCachedConv2dTestKey") for one combination of element types and
/// layouts. Note the reference template takes Te before Tacc, while the key
/// factory takes Tacc before Te.
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host(py::module &m) {
  m.def("conv2d",
        &cutlass::reference::host::Conv2d<Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);
  m.def("CreateCachedConv2dTestKey",
        &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}
/// Same as bind_conv2d_host, but the reference convolution converts its
/// output through NumericConverterClamp (saturating conversion to Tc).
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host_sat(py::module &m) {
  m.def("conv2d",
        &cutlass::reference::host::Conv2d<Ta, La, Tb, Lb, Tc, Lc, Te, Tacc,
                                          cutlass::NumericConverterClamp<Tc, Te>>);
  m.def("CreateCachedConv2dTestKey",
        &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}
/// Convenience wrapper: bind the host Conv2d reference with all three
/// operands in the NHWC layout.
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nhwc(py::module &m) {
  using Layout = cutlass::layout::TensorNHWC;
  bind_conv2d_host<Ta, Layout, Tb, Layout, Tc, Layout, Tacc, Te>(m);
}
/// Convenience wrapper: bind the saturating host Conv2d reference for the
/// 32-element interleaved layouts (activation/output NCxHWx<32>, filter
/// CxRSKx<32>), as used by the integer paths.
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nc32hw32(py::module &m) {
  bind_conv2d_host_sat<Ta, cutlass::layout::TensorNCxHWx<32>,
                       Tb, cutlass::layout::TensorCxRSKx<32>,
                       Tc, cutlass::layout::TensorNCxHWx<32>,
                       Tacc, Te>(m);
}
/// Bind the host tensor comparison as "equals" for one (element, layout)
/// combination, selecting the two-TensorView overload explicitly.
template<typename T, typename Layout>
void bind_tensor_equals(py::module &m) {
  using View = cutlass::TensorView<T, Layout>;
  m.def("equals",
        py::overload_cast<const View&, const View&>(
            &cutlass::reference::host::TensorEquals<T, Layout>));
}
// Binds test::conv::device::TensorHash<Element, Layout> as "TensorHash" with
// keyword arguments: `view` (tensor to hash), `hash` (CRC32 table, default
// constructed) and `crc` (initial CRC value, zero by default).
#define BIND_TENSOR_HASH(Element, Layout) { \
  m.def("TensorHash", &test::conv::device::TensorHash<Element, Layout>, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \
}
/// Register host-side Conv2d reference helpers on `m`:
///   - "conv2d" / "CreateCachedConv2dTestKey" for each supported type combo
///   - "equals" tensor comparison for each supported (element, layout) pair
///   - cached-result bookkeeping (CachedTestKey, CachedTestResult,
///     CachedTestResultListing, CRC32, TensorHash)
/// Fixed: several identical registrations were repeated; each combination is
/// now bound exactly once.
void bind_conv_host_references(py::module &m) {
  //
  // Conv2d reference on host
  // tools/util/include/cutlass/util/reference/host/convolution.h

  /// double
  bind_conv2d_host_nhwc<double, double, double, double, double>(m);

  /// float
  bind_conv2d_host_nhwc<float, float, float, float, float>(m);

  /// half
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, float>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, float>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, float>(m);

  /// bfloat16
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t>(m);
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, cutlass::bfloat16_t>(m);
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);

  /// s8 (NHWC)
  bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, float>(m);

  /// s8 (interleaved NC/32HW32, saturating)
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);

  //
  // Compare whether two tensors are equal
  //

  /// double
  bind_tensor_equals<double, cutlass::layout::TensorNHWC>(m);
  /// float
  bind_tensor_equals<float, cutlass::layout::TensorNHWC>(m);
  /// half
  bind_tensor_equals<cutlass::half_t, cutlass::layout::TensorNHWC>(m);
  /// bfloat16
  bind_tensor_equals<cutlass::bfloat16_t, cutlass::layout::TensorNHWC>(m);
  /// s32
  bind_tensor_equals<int32_t, cutlass::layout::TensorNHWC>(m);
  bind_tensor_equals<int32_t, cutlass::layout::TensorNCxHWx<32>>(m);
  /// s8
  bind_tensor_equals<int8_t, cutlass::layout::TensorNHWC>(m);
  bind_tensor_equals<int8_t, cutlass::layout::TensorNCxHWx<32>>(m);

  /// Cached-reference-result bookkeeping
  py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
    .def(py::init<>())
    .def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>());

  py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
    .def(py::init<>())
    .def(py::init<uint32_t>())
    .def_readwrite("D", &test::conv::device::CachedTestResult::D);

  py::class_<test::conv::device::CachedTestResultListing>(m, "CachedTestResultListing")
    .def(py::init<const std::string &>())
    .def("find", &test::conv::device::CachedTestResultListing::find)
    .def("append", &test::conv::device::CachedTestResultListing::append)
    .def("write", &test::conv::device::CachedTestResultListing::write);

  py::class_<test::conv::device::CRC32>(m, "CRC32")
    .def(py::init<>());

  // CRC-based tensor hashing used to build cache keys.
  BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>);
}

View File

@ -0,0 +1,45 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm test to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "host.h"
namespace py = pybind11;
/// Register GEMM test bindings; host reference helpers go under `<m>.host`.
void bind_gemm_test(py::module &m) {
  py::module_ host = m.def_submodule("host");
  bind_gemm_host_reference(host);
}

View File

@ -0,0 +1,431 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm test host functions to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/functional.h"
namespace py = pybind11;
/// Bind the saturating host GEMM reference as "gemm_saturate" for one
/// combination of element types and layouts. The accumulator is converted to
/// ElementC through NumericConverterClamp (saturation on overflow).
template<
  typename ElementA, typename LayoutA,
  typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC,
  typename AccumulatorType, typename ComputeType,
  typename InnerProductOp>
void bind_host_gemm_saturate(py::module &m) {
  using RefA = cutlass::TensorRef<ElementA, LayoutA>;
  using RefB = cutlass::TensorRef<ElementB, LayoutB>;
  using RefC = cutlass::TensorRef<ElementC, LayoutC>;
  // overload_cast selects the compute_gemm overload whose parameters are
  // (GemmCoord, ComputeType, A, B, ComputeType, C, C, AccumulatorType).
  m.def("gemm_saturate",
        py::overload_cast<cutlass::gemm::GemmCoord, ComputeType,
                          RefA, RefB, ComputeType, RefC, RefC,
                          AccumulatorType>(
            &cutlass::reference::host::compute_gemm<
                ElementA, LayoutA,
                ElementB, LayoutB,
                ElementC, LayoutC,
                ComputeType,
                AccumulatorType,
                InnerProductOp,
                cutlass::NumericConverterClamp<ElementC, AccumulatorType>>));
}
/// Bind the host GEMM reference as "gemm" for one combination of element
/// types and layouts, using the default (non-saturating) NumericConverter
/// to produce ElementC output.
template<
  typename ElementA, typename LayoutA,
  typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC,
  typename AccumulatorType, typename ComputeType,
  typename InnerProductOp>
void bind_host_gemm(py::module &m) {
  using RefA = cutlass::TensorRef<ElementA, LayoutA>;
  using RefB = cutlass::TensorRef<ElementB, LayoutB>;
  using RefC = cutlass::TensorRef<ElementC, LayoutC>;
  // overload_cast selects the compute_gemm overload whose parameters are
  // (GemmCoord, ComputeType, A, B, ComputeType, C, C, AccumulatorType).
  m.def("gemm",
        py::overload_cast<cutlass::gemm::GemmCoord, ComputeType,
                          RefA, RefB, ComputeType, RefC, RefC,
                          AccumulatorType>(
            &cutlass::reference::host::compute_gemm<
                ElementA, LayoutA,
                ElementB, LayoutB,
                ElementC, LayoutC,
                ComputeType,
                AccumulatorType,
                InnerProductOp,
                cutlass::NumericConverter<ElementC, AccumulatorType>>));
}
/// Bind the host GEMM reference ("gemm") with a multiply-add inner product
/// for all eight RowMajor/ColumnMajor layout combinations of A, B, and C.
///
/// Fixed: the first instantiation passed <ComputeType, AccumulatorType>
/// (swapped relative to the other seven). With the swapped order the
/// multiply_add<AccumulatorType> inner-product op no longer matches the
/// accumulator type chosen inside bind_host_gemm whenever the two types
/// differ (e.g. half/float combinations). All eight now use
/// <AccumulatorType, ComputeType>.
template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add(py::module &m) {
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}
/// Bind the saturating host GEMM reference ("gemm_saturate") with a
/// multiply-add inner product for all eight RowMajor/ColumnMajor layout
/// combinations of A, B, and C.
///
/// Fixed: the first instantiation passed <ComputeType, AccumulatorType>
/// (swapped relative to the other seven), mismatching
/// multiply_add<AccumulatorType> whenever the two types differ. All eight
/// now use <AccumulatorType, ComputeType>.
template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate(py::module &m) {
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}
/// Bind the host GEMM reference ("gemm") with a multiply-add inner product
/// for all eight 32-element interleaved layout combinations of A, B, and C.
///
/// Fixed: the first instantiation passed <ComputeType, AccumulatorType>
/// (swapped relative to the other seven), mismatching
/// multiply_add<AccumulatorType> whenever the two types differ. All eight
/// now use <AccumulatorType, ComputeType>.
template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_interleaved(py::module &m) {
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}
/// Bind the saturating host GEMM reference ("gemm_saturate") with a
/// multiply-add inner product for all eight 32-element interleaved layout
/// combinations of A, B, and C.
///
/// Fixed: the first instantiation passed <ComputeType, AccumulatorType>
/// (swapped relative to the other seven), mismatching
/// multiply_add<AccumulatorType> whenever the two types differ. All eight
/// now use <AccumulatorType, ComputeType>.
template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) {
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}
// Binds the two-TensorView overload of TensorEquals<Element, Layout> as
// "equals" on module `m` (same effect as bind_tensor_equals, macro form).
#define BIND_TENSOR_EQUAL(Element, Layout) { \
  m.def("equals", py::overload_cast< \
    const cutlass::TensorView<Element, Layout>&, const cutlass::TensorView<Element, Layout>&>( \
    &cutlass::reference::host::TensorEquals<Element, Layout>)); \
}
/// Register host-side GEMM reference helpers on `m`:
///   - "gemm" / "gemm_saturate" for each supported element-type combination
///     (plain, interleaved, saturating, saturating-interleaved variants)
///   - "equals" tensor comparison for each supported (element, layout) pair
/// Fixed: several identical registrations were repeated; each combination is
/// now bound exactly once.
void bind_gemm_host_reference(py::module &m) {
  /// double
  bind_host_gemm_multiply_add<double, double, double, double, double>(m);

  /// float
  bind_host_gemm_multiply_add<float, float, float, float, float>(m);

  /// half_t
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, float, float>(m);

  /// bfloat16
  bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
  bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);

  /// s8
  bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, float>(m);

  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);

  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);

  bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);

  // Tensor equality comparisons
  // float
  BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor);
  BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor);
  // double
  BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor);
  BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor);
  // half_t
  BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor);
  BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor);
  // bfloat16
  BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor);
  BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor);
  // int32_t
  BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor);
  BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor);
  // int8_t
  BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor);
  BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor);
  BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>);
  BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>);
}