75
tools/library/scripts/pycutlass/src/cpp/compiler.h
Normal file
75
tools/library/scripts/pycutlass/src/cpp/compiler.h
Normal file
@ -0,0 +1,75 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief In-memory compiled artifact cache
|
||||
*/
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
struct CompileCache {
|
||||
public:
|
||||
CompileCache() = default;
|
||||
~CompileCache() = default;
|
||||
|
||||
using Cache = std::unordered_map<std::string, py::object>;
|
||||
|
||||
/// Check if the kernel has already been compiled
|
||||
py::object at(const std::string &kernel) {
|
||||
auto item = cache_.find(kernel);
|
||||
|
||||
if (item != cache_.end()) {
|
||||
return item->second;
|
||||
}
|
||||
return py::none();
|
||||
}
|
||||
|
||||
/// Insert a new compiled kernel for new configuration
|
||||
void insert(const std::string &kernel, const py::object &compiled_kernel){
|
||||
cache_.emplace(kernel, compiled_kernel);
|
||||
}
|
||||
|
||||
const int64_t size() const { return cache_.size(); }
|
||||
|
||||
/// Clear the cache
|
||||
void clear() { cache_.clear(); }
|
||||
|
||||
private:
|
||||
Cache cache_;
|
||||
};
|
||||
|
||||
} // namespace cutlass
|
||||
181
tools/library/scripts/pycutlass/src/cpp/cutlass.cpp
Normal file
181
tools/library/scripts/pycutlass/src/cpp/cutlass.cpp
Normal file
@ -0,0 +1,181 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief binding cutlass C++ APIs to python
|
||||
*/
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "builtin_types.h"
|
||||
#include "device_launch_parameters.h"
|
||||
#include "stddef.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "include/conv/convolution.h"
|
||||
#include "include/gemm/gemm.h"
|
||||
#include "include/types.h"
|
||||
#include "include/layout/layout.h"
|
||||
#include "include/tensor_coord.h"
|
||||
#include "include/arch.h"
|
||||
#include "include/tensor_ref_view.h"
|
||||
#include "include/swizzling.h"
|
||||
#include "test/conv/convolution.h"
|
||||
#include "test/gemm/gemm.h"
|
||||
|
||||
|
||||
// Data Types
|
||||
#include "library.h"
|
||||
|
||||
// compiler
|
||||
#include "compiler.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
PYBIND11_MODULE(cutlass, m) {
  // Local shorthands for the types whose enumerators / members are bound below.
  using Dtype = cutlass::DataType;
  using Layout = cutlass::LayoutType;
  using Transform = cutlass::ComplexTransform;
  using KernelCache = cutlass::CompileCache;

  // Module-level docstring
  m.doc() = "cutlass C++ binding";

  // ---- Core bindings living directly on the top-level module -------------

  // data types
  bind_cutlass_types(m);

  // layouts
  bind_layout(m);

  // tensor coordinates
  bind_tensor_coord(m);

  // tensor refs and views
  bind_tensor_refs_and_views(m);

  // opcode classes
  bind_opcode(m);

  // ---- `conv` submodule --------------------------------------------------
  auto conv_mod = m.def_submodule("conv");
  bind_convolution(conv_mod);

  // ---- `gemm` submodule --------------------------------------------------
  auto gemm_mod = m.def_submodule("gemm");
  bind_gemm(gemm_mod);

  // ---- threadblock swizzling ---------------------------------------------
  bind_threadblock_swizzle(m);

  // ---- `test` submodule with its `conv` and `gemm` children --------------
  auto test_mod = m.def_submodule("test");

  auto test_conv_mod = test_mod.def_submodule("conv");
  bind_convolution_test(test_conv_mod);

  auto test_gemm_mod = test_mod.def_submodule("gemm");
  bind_gemm_test(test_gemm_mod);

  // ---- `dtype` enumeration -----------------------------------------------
  py::enum_<Dtype>(m, "dtype")
      .value("b1", Dtype::kB1)
      .value("u2", Dtype::kU2)
      .value("u4", Dtype::kU4)
      .value("u8", Dtype::kU8)
      .value("u16", Dtype::kU16)
      .value("u32", Dtype::kU32)
      .value("u64", Dtype::kU64)
      .value("s2", Dtype::kS2)
      .value("s4", Dtype::kS4)
      .value("s16", Dtype::kS16)
      .value("s64", Dtype::kS64)
      .value("cf16", Dtype::kCF16)
      .value("cbf16", Dtype::kCBF16)
      .value("cf32", Dtype::kCF32)
      .value("ctf32", Dtype::kCTF32)
      .value("cf64", Dtype::kCF64)
      .value("cs2", Dtype::kCS2)
      .value("cs4", Dtype::kCS4)
      .value("cs8", Dtype::kCS8)
      .value("cs16", Dtype::kCS16)
      .value("cs32", Dtype::kCS32)
      .value("cs64", Dtype::kCS64)
      .value("cu2", Dtype::kCU2)
      .value("cu4", Dtype::kCU4)
      .value("cu8", Dtype::kCU8)
      .value("cu16", Dtype::kCU16)
      .value("cu32", Dtype::kCU32)
      .value("cu64", Dtype::kCU64)
      .value("invalid", Dtype::kInvalid);

  // ---- `layout` enumeration ----------------------------------------------
  py::enum_<Layout>(m, "layout")
      .value("ColumnMajorInterleaved2", Layout::kColumnMajorInterleaved2)
      .value("RowMajorInterleaved2", Layout::kRowMajorInterleaved2)
      .value("ColumnMajorInterleaved64", Layout::kColumnMajorInterleaved64)
      .value("RowMajorInterleaved64", Layout::kRowMajorInterleaved64)
      .value("TensorNDHWC", Layout::kTensorNDHWC)
      .value("TensorNCHW", Layout::kTensorNCHW)
      .value("TensorNGHWC", Layout::kTensorNGHWC)
      .value("TensorNC64HW64", Layout::kTensorNC64HW64)
      .value("TensorC64RSK64", Layout::kTensorC64RSK64);

  // ---- `complex_transform` enumeration -----------------------------------
  py::enum_<Transform>(m, "complex_transform")
      .value("none", Transform::kNone)
      .value("conj", Transform::kConjugate);

  // ---- Compiler: expose the compiled-kernel cache class ------------------
  py::class_<KernelCache>(m, "CompileCache")
      .def(py::init<>())
      .def("at", &KernelCache::at)
      .def("insert", &KernelCache::insert)
      .def("size", &KernelCache::size)
      .def("clear", &KernelCache::clear);
}
|
||||
59
tools/library/scripts/pycutlass/src/cpp/include/arch.h
Normal file
59
tools/library/scripts/pycutlass/src/cpp/include/arch.h
Normal file
@ -0,0 +1,59 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind opcode classes to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace cutlass {

/// Enumerated classification of math operators, exposed to Python as `OpClass`.
enum class OpcodeClass {
  /// Thread-level operations
  kSimt,
  /// Tensor Core operations
  kTensorOp,
  /// WMMA Tensor Core operations
  kWmmaTensorOp,
  /// Sparse Tensor Core operations
  kSparseTensorOp
};

}  // namespace cutlass
|
||||
|
||||
/// Register the `OpClass` enumeration (backed by cutlass::OpcodeClass) on \p m.
///
/// Declared `inline` because the definition lives in a header: a non-inline
/// definition would produce duplicate symbols as soon as the header is
/// included from more than one translation unit.
///
/// \param m  module that receives the `OpClass` enum
inline void bind_opcode(py::module &m) {
  py::enum_<cutlass::OpcodeClass>(m, "OpClass",
      R"pbdoc(classification of math operators)pbdoc")
    .value("Simt", cutlass::OpcodeClass::kSimt,
      R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc")
    .value("TensorOp", cutlass::OpcodeClass::kTensorOp,
      R"pbdoc(Tag classifing operators as Tensor Core operations)pbdoc")
    .value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp,
      R"pbdoc(Tag classifing operators as WMMA Tensor Core operations)pbdoc")
    .value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp,
      R"pbdoc(Tag classifing operators as sparseTensor Core operations)pbdoc");
}
|
||||
@ -0,0 +1,102 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind Convolution problem sizes to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Register cutlass::conv::Conv2dProblemSize and the implicit-GEMM size /
/// extent helper functions on \p m.
///
/// Declared `inline` because the definition lives in a header: a non-inline
/// definition would produce duplicate symbols as soon as the header is
/// included from more than one translation unit.
///
/// \param m  module that receives the class and the free functions
inline void bind_conv_problem_size(py::module &m) {
  //
  // Conv2d Problem Size:
  //   include/cutlass/conv/conv2d_problem_size.h
  //
  py::class_<cutlass::conv::Conv2dProblemSize>(m, "Conv2dProblemSize")
    // constructors: one taking scalar arguments, one taking coordinate objects
    .def(py::init<int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, cutlass::conv::Mode, int, int>())
    .def(py::init<cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::MatrixCoord, cutlass::MatrixCoord, cutlass::conv::Mode, int, int>())
    // attribute accessors (read/write from Python)
    .def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N)
    .def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H)
    .def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W)
    .def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C)
    .def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P)
    .def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q)
    .def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K)
    .def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R)
    .def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S)
    .def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h)
    .def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w)
    .def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h)
    .def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w)
    .def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h)
    .def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w)
    .def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode)
    .def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices)
    .def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups)
    // member functions
    .def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices)
    .def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent)
    .def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent)
    .def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent)
    .def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size)
    .def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size)
    .def("output_size", &cutlass::conv::Conv2dProblemSize::output_size);

  // Get tensor size.
  // overload_cast selects the (Operator, const Conv2dProblemSize&) overload of each helper.
  m.def("implicit_gemm_tensor_a_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_a_size));
  m.def("implicit_gemm_tensor_b_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_b_size));
  m.def("implicit_gemm_tensor_c_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_c_size));

  // Get tensor extent
  m.def("implicit_gemm_tensor_a_extent",
    py::overload_cast<
      cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
    >(&cutlass::conv::implicit_gemm_tensor_a_extent));

  m.def("implicit_gemm_tensor_b_extent",
    py::overload_cast<
      cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
    >(&cutlass::conv::implicit_gemm_tensor_b_extent));

  m.def("implicit_gemm_tensor_c_extent",
    py::overload_cast<
      cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
    >(&cutlass::conv::implicit_gemm_tensor_c_extent));

  // Equivalent GEMM problem size for a convolution
  m.def("implicit_gemm_problem_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize &>(&cutlass::conv::implicit_gemm_problem_size));

}
|
||||
@ -0,0 +1,91 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind convolution related enum types to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "conv_problem_size.h"
|
||||
#include "host.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Register the convolution enum types, the conv problem-size bindings and
/// the `host` helper submodule on \p m.
///
/// Declared `inline` because the definition lives in a header: a non-inline
/// definition would produce duplicate symbols as soon as the header is
/// included from more than one translation unit.
///
/// \param m  module that receives the convolution bindings
inline void bind_convolution(py::module &m) {
  //
  // Enumerate types
  //   cutlass/include/cutlass/conv/convolution.h
  //

  /// Convolutional operator
  py::enum_<cutlass::conv::Operator>(m, "Operator", R"pbdoc(Convolutional operator)pbdoc")
    .value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation")
    .value("dgrad", cutlass::conv::Operator::kDgrad, "Activation grad")
    .value("wgrad", cutlass::conv::Operator::kWgrad, "Weight grad");

  /// Distinguishes convolution from cross correlation
  py::enum_<cutlass::conv::Mode>(m, "Mode")
    .value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation)
    .value("convolution", cutlass::conv::Mode::kConvolution);

  /// Selects among several implementation variants trading off performance with simplicity
  py::enum_<cutlass::conv::IteratorAlgorithm>(m, "IteratorAlgorithm",
      R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc")
    .value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(functionally correct in all cases but lower performance)pbdoc")
    .value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc")
    .value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc")
    .value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc");

  /// Distinguishes among partial specializations that accelerate certain problems where convolution
  /// stride is unit.
  // NOTE: the docstring below is a raw string that spans two source lines; its
  // embedded line break is part of the text shown to Python users.
  py::enum_<cutlass::conv::StrideSupport>(m, "StrideSupport",
      R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution
stride is unit.)pbdoc")
    .value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc")
    .value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc");

  /// Identifies split-K mode
  py::enum_<cutlass::conv::SplitKMode>(m, "SplitKMode")
    .value("None", cutlass::conv::SplitKMode::kNone)
    .value("Serial", cutlass::conv::SplitKMode::kSerial)
    .value("Parallel", cutlass::conv::SplitKMode::kParallel);

  // Conv problem sizes
  bind_conv_problem_size(m);

  //
  // host helper functions, exposed under the `host` submodule
  //
  py::module_ host_submodule = m.def_submodule("host");
  bind_conv_host_helper(host_submodule);
}
|
||||
54
tools/library/scripts/pycutlass/src/cpp/include/conv/host.h
Normal file
54
tools/library/scripts/pycutlass/src/cpp/include/conv/host.h
Normal file
@ -0,0 +1,54 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind conv host helpers to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
/// Register host-side convolution helper functions on \p m.
///
/// Declared `inline` because the definition lives in a header: a non-inline
/// definition would produce duplicate symbols as soon as the header is
/// included from more than one translation unit.
///
/// \param m  module that receives the helper functions
inline void bind_conv_host_helper(py::module &m) {

  /// reorder operand B for interleaved layout
  ///
  /// Python signature: reorder_convK(dest, src, conv_op, problem_size).
  /// The convolution problem is first converted to its implicit GEMM problem
  /// size, which is then passed to cutlass::reorder_convK<32> together with
  /// the int8 TensorCxRSKx<32> tensor refs.
  m.def("reorder_convK", [](
    cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> dest,
    cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> src,
    cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize & problem_size) {
      cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size);
      cutlass::reorder_convK<32>(dest, src, implicit_problem_size);
    });
}
|
||||
77
tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h
Normal file
77
tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h
Normal file
@ -0,0 +1,77 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind gemm related enum types to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "host.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Register the GEMM mode enum, the GemmCoord class and the `host` helper
/// submodule on \p m.
///
/// Declared `inline` because the definition lives in a header: a non-inline
/// definition would produce duplicate symbols as soon as the header is
/// included from more than one translation unit.
///
/// \param m  module that receives the GEMM bindings
inline void bind_gemm(py::module &m) {
  //
  // Enumerate types
  //   cutlass/gemm/gemm.h
  //

  py::enum_<cutlass::gemm::GemmUniversalMode>(m, "Mode")
    .value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial")
    .value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel")
    .value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM")
    .value("Array", cutlass::gemm::GemmUniversalMode::kArray)
    .value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid);

  /// GemmCoord is a structure that specifies a location within the coordinate space of a GEMM problem
  py::class_<cutlass::gemm::GemmCoord>(m, "GemmCoord")
    .def(py::init<int, int, int>())
    // overload_cast<>() selects the zero-argument, non-const overload of each accessor
    .def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m))
    .def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n))
    .def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k))
    // get tensor coords: each pair is returned as a cutlass::MatrixCoord
    .def("mk",
      [](const cutlass::gemm::GemmCoord & problem_size) {
        return cutlass::MatrixCoord(problem_size.mk());
      })
    .def("kn",
      [](const cutlass::gemm::GemmCoord & problem_size) {
        return cutlass::MatrixCoord(problem_size.kn());
      })
    .def("mn",
      [](const cutlass::gemm::GemmCoord & problem_size) {
        return cutlass::MatrixCoord(problem_size.mn());
      });

  // host helper functions, exposed under the `host` submodule
  py::module_ host_submodule = m.def_submodule("host");
  bind_gemm_host_helper(host_submodule);
}
|
||||
47
tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h
Normal file
47
tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h
Normal file
@ -0,0 +1,47 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind gemm host helpers to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
/// Registers the host-side GEMM helper functions on module `m`.
///
/// Two instantiations of cutlass::reorder_column (interleave factor 32, int8_t
/// elements) are registered under the same Python name `reorder_column`, one
/// for each interleaved layout, so they form an overload set on the Python side.
///
/// Declared `inline` because the definition lives in a header; without it, each
/// including translation unit would emit its own definition and linking fails.
///
/// \param m pybind11 module that receives the functions
inline void bind_gemm_host_helper(py::module &m) {
    m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>);
    m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>);
}
|
||||
@ -0,0 +1,47 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind CUTLASS layouts to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "tensor.h"
|
||||
#include "matrix.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Registers every CUTLASS layout binding on module `m`.
///
/// Delegates to bind_tensor_layout() and bind_matrix_layout(), in that order.
///
/// Declared `inline` because the definition lives in a header; a non-inline
/// definition would be duplicated in every including translation unit.
///
/// \param m pybind11 module that receives the layout classes
inline void bind_layout(py::module &m) {
    bind_tensor_layout(m);
    bind_matrix_layout(m);
}
|
||||
@ -0,0 +1,87 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind Matrix layouts to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Registers the matrix layout classes on module `m`.
///
/// Bound classes: `RowMajor`, `ColumnMajor`, `RowMajorInterleaved32` and
/// `ColumnMajorInterleaved32`. Each one exposes a static `packed(extent)`
/// factory and a `stride()` method that returns element 0 of the layout's
/// stride.
///
/// Declared `inline` because the definition lives in a header; a non-inline
/// definition would be duplicated in every including translation unit.
///
/// \param m pybind11 module that receives the layout classes
inline void bind_matrix_layout(py::module &m) {
    //
    // Matrix layouts
    // cutlass/layout/matrix.h
    //

    py::class_<cutlass::layout::RowMajor>(m, "RowMajor", R"pbdoc(
        Mapping function for row-major matrices.
    )pbdoc")
        .def_static("packed", &cutlass::layout::RowMajor::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
        .def("stride", [](const cutlass::layout::RowMajor & layout){
            return layout.stride().at(0);
        }, R"pbdoc(Returns the stride of the layout)pbdoc");

    py::class_<cutlass::layout::ColumnMajor>(m, "ColumnMajor", R"pbdoc(
        Mapping function for column-major matrices.
    )pbdoc")
        .def_static("packed", &cutlass::layout::ColumnMajor::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc" )
        .def("stride", [](const cutlass::layout::ColumnMajor & layout){
            return layout.stride().at(0);
        }, R"pbdoc(Returns the stride of the layout)pbdoc");

    py::class_<cutlass::layout::RowMajorInterleaved<32>>(m, "RowMajorInterleaved32",
        R"pbdoc(Mapping function for interleaved matrices. Matrix is structured
        as row-major arrangement of fixed-size columns 32)pbdoc")
        .def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
        .def("stride", [](const cutlass::layout::RowMajorInterleaved<32> & layout){
            return layout.stride().at(0);
        }, R"pbdoc(Returns the stride of the layout)pbdoc");

    py::class_<cutlass::layout::ColumnMajorInterleaved<32>>(m, "ColumnMajorInterleaved32",
        R"pbdoc(Mapping function for interleaved matrices. Matrix is structured
        as column-major arrangement of fixed-size rows 32)pbdoc")
        .def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
        .def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> & layout){
            return layout.stride().at(0);
        }, R"pbdoc(Returns the stride of the layout)pbdoc");
}
|
||||
@ -0,0 +1,74 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind Tensor layouts to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/layout/tensor.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Registers the 4-D tensor layout classes on module `m`.
///
/// Bound classes: `TensorNHWC`, `TensorNC32HW32` (TensorNCxHWx<32>) and
/// `TensorC32RSK32` (TensorCxRSKx<32>). Each one exposes a static
/// `packed(extent)` factory and a `stride()` accessor.
///
/// Declared `inline` because the definition lives in a header; a non-inline
/// definition would be duplicated in every including translation unit.
///
/// \param m pybind11 module that receives the layout classes
inline void bind_tensor_layout(py::module &m) {
    //
    // Tensor layouts
    // cutlass/include/cutlass/layout/tensor.h
    //

    /// Mapping function for 4-D NHWC tensors.
    py::class_<cutlass::layout::TensorNHWC>(m, "TensorNHWC",
        R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc")
        .def_static("packed", &cutlass::layout::TensorNHWC::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed NHWC tensor)pbdoc")
        // No py::const_ tag: the non-const zero-argument stride() overload is bound.
        .def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride),
            R"pbdoc(Returns the stride of the layout)pbdoc");

    /// Mapping function for 4-D NC/xHWx tensors.
    py::class_<cutlass::layout::TensorNCxHWx<32>>(m, "TensorNC32HW32",
        R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc")
        .def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
        .def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride),
            R"pbdoc(Returns the stride of the layout)pbdoc");

    /// Mapping function for 4-D CxRSKx tensors.
    py::class_<cutlass::layout::TensorCxRSKx<32>>(m, "TensorC32RSK32",
        R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc")
        .def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed,
            py::arg("extent"),
            R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
        .def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride),
            R"pbdoc(Returns the stride of the layout)pbdoc");
}
|
||||
152
tools/library/scripts/pycutlass/src/cpp/include/swizzling.h
Normal file
152
tools/library/scripts/pycutlass/src/cpp/include/swizzling.h
Normal file
@ -0,0 +1,152 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind threadblock swizzling to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include <boost/core/demangle.hpp>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
template<typename T>
|
||||
void bind_identity_swizzle(py::module & m, std::string name) {
|
||||
py::class_<T>(m, name.c_str(),
|
||||
R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc")
|
||||
.def(py::init<>())
|
||||
.def("get_tiled_shape",
|
||||
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
|
||||
&T::get_tiled_shape, py::const_
|
||||
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
|
||||
R"pbdoc(Returns the shape of the problem in units of logical tiles
|
||||
|
||||
:param problem_size: gemm(M, N, K)
|
||||
:type problem_size: :class:`cutlass.gemm.GemmCoord`
|
||||
)pbdoc")
|
||||
.def("get_tiled_shape",
|
||||
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
|
||||
&T::get_tiled_shape, py::const_
|
||||
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
|
||||
R"pbdoc(Returns the shape of the problem in units of logical tiles
|
||||
|
||||
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
|
||||
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
|
||||
)pbdoc")
|
||||
.def("get_tiled_shape",
|
||||
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv3dProblemSize&, cutlass::gemm::GemmCoord, int>(
|
||||
&T::get_tiled_shape, py::const_
|
||||
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
|
||||
R"pbdoc(Returns the shape of the problem in units of logical tiles
|
||||
|
||||
:param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC)
|
||||
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
|
||||
)pbdoc")
|
||||
// TODO: the returned dim3 is not usable in python
|
||||
.def("get_grid_shape", &T::get_grid_shape,
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void bind_swizzle(py::module & m, std::string name, std::string doc) {
|
||||
py::class_<T>(m, name.c_str(), doc.c_str())
|
||||
.def(py::init<>())
|
||||
.def("get_tiled_shape",
|
||||
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
|
||||
&T::get_tiled_shape, py::const_
|
||||
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
|
||||
R"pbdoc(Returns the shape of the problem in units of logical tiles
|
||||
|
||||
:param problem_size: gemm(M, N, K)
|
||||
:type problem_size: :class:`cutlass.gemm.GemmCoord`
|
||||
)pbdoc")
|
||||
.def("get_grid_shape", &T::get_grid_shape,
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void bind_dgrad_swizzle(py::module & m, std::string name) {
|
||||
py::class_<T>(m, name.c_str(),
|
||||
R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc")
|
||||
.def(py::init<>())
|
||||
.def("get_tiled_shape",
|
||||
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
|
||||
&T::get_tiled_shape, py::const_
|
||||
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
|
||||
R"pbdoc(Returns the shape of the problem in units of logical tiles
|
||||
|
||||
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
|
||||
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
|
||||
)pbdoc")
|
||||
.def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) {
|
||||
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
|
||||
}, py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
/// Registers `dim3` and all threadblock swizzle classes on module `m`.
///
/// `dim3` is bound first with a three-integer constructor and read/write
/// `x`, `y`, `z` fields. The swizzle functors are then bound through
/// bind_identity_swizzle(), bind_swizzle() and bind_dgrad_swizzle().
///
/// Declared `inline` because the definition lives in a header; a non-inline
/// definition would be duplicated in every including translation unit.
///
/// \param m pybind11 module that receives the classes
inline void bind_threadblock_swizzle(py::module &m) {

    // NOTE: the docstring text says "int3" although the bound type is dim3.
    py::class_<dim3>(m, "dim3",
        R"pbdoc(A int3 type xyz contains three integers)pbdoc")
        .def(py::init<int, int, int>(),
            py::arg("x"), py::arg("y"), py::arg("z"))
        .def_readwrite("x", &dim3::x, R"pbdoc(get value x)pbdoc")
        .def_readwrite("y", &dim3::y, R"pbdoc(get value y)pbdoc")
        .def_readwrite("z", &dim3::z, R"pbdoc(get value z)pbdoc");

    // GEMM identity swizzles, one Python class per template argument.
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>>(m, "IdentitySwizzle1");
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>>(m, "IdentitySwizzle2");
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>>(m, "IdentitySwizzle4");
    bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>>(m, "IdentitySwizzle8");

    bind_swizzle<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle>(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc");
    bind_swizzle<cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc");

    // Strided dgrad convolution swizzles.
    bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>>(m, "StridedDgradIdentitySwizzle1");
    bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>>(m, "StridedDgradIdentitySwizzle4");
    bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle>(m, "StridedDgradHorizontalSwizzle");
}
|
||||
@ -0,0 +1,72 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind Tensor Coord to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Registers the tensor coordinate classes on module `m`.
///
/// Bound classes:
///  - `Tensor4DCoord`: constructor from (n, h, w, c) only.
///  - `Tensor3DCoord` (cutlass::Coord<3>): `at(dim)` accessor only; no
///    constructor is bound, so Python obtains instances from other bindings.
///  - `MatrixCoord`: constructor from (row, column) plus `row()` / `column()`.
///
/// Declared `inline` because the definition lives in a header; a non-inline
/// definition would be duplicated in every including translation unit.
///
/// \param m pybind11 module that receives the classes
inline void bind_tensor_coord(py::module &m) {
    //
    // Tensor Coords
    // cutlass/include/cutlass/tensor_coord.h
    //

    /// Defines a canonical 4D coordinate used by tensor operations.
    py::class_<cutlass::Tensor4DCoord>(m, "Tensor4DCoord",
        R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
        .def(py::init<int, int, int, int>(),
            py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
            R"pbdoc(Helper to construct from N, H, W, and C)pbdoc");

    py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
        R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
        .def("at", py::overload_cast<int>(&cutlass::Coord<3>::at),
            py::arg("dim"),
            R"pbdoc(Gets the index of a given Coord element)pbdoc");

    // Matrix Size
    py::class_<cutlass::MatrixCoord>(m, "MatrixCoord",
        R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
        expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc")
        .def(py::init<int, int>(),
            py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc")
        .def("row", py::overload_cast<>(&cutlass::MatrixCoord::row),
            R"pbdoc(Returns the row of the coordinate)pbdoc")
        .def("column", py::overload_cast<>(&cutlass::MatrixCoord::column),
            R"pbdoc(Returns the column of the coordinate)pbdoc");

}
|
||||
@ -0,0 +1,102 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind TensorRef and View to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "types.h"
|
||||
|
||||
|
||||
template<typename T, typename L, typename TF>
|
||||
void bind_tensor_ref_view(py::module &m, std::string name) {
|
||||
py::class_<cutlass::TensorRef<T, L>>(m, ("TensorRef" + name).c_str())
|
||||
.def("__init__", [](cutlass::TensorRef<T, L>& tensor_ref, int64_t address, const L& layout_ ) {
|
||||
T* ptr = reinterpret_cast< T*>(address);
|
||||
new (&tensor_ref) cutlass::TensorRef<T, L>(ptr, layout_);
|
||||
})
|
||||
.def("data", [](cutlass::TensorRef<T, L>& tensor_ref) {
|
||||
T* ptr = tensor_ref.data();
|
||||
return int64_t(ptr);
|
||||
})
|
||||
.def("layout", py::overload_cast<>(&cutlass::TensorRef<T, L>::layout));
|
||||
|
||||
m.def("get_tensor_ref", [](int64_t address, TF data, const L& layout_) {
|
||||
T* ptr = reinterpret_cast<T*>(address);
|
||||
cutlass::TensorRef<T, L> tensor_ref = cutlass::TensorRef<T, L>(ptr, layout_);
|
||||
return tensor_ref;
|
||||
});
|
||||
|
||||
py::class_<cutlass::TensorView<T, L>>(m, ("TensorView" + name).c_str())
|
||||
.def(py::init<const cutlass::TensorRef<T, L>&, const typename L::TensorCoord &>());
|
||||
}
|
||||
|
||||
|
||||
/// Registers TensorRef / TensorView bindings for every supported
/// (element type, layout) combination on module `m`.
///
/// Each call instantiates bind_tensor_ref_view() with the element type, the
/// layout and the tag type, and passes the suffix used to form the Python
/// class names.
///
/// Declared `inline` because the definition lives in a header; a non-inline
/// definition would be duplicated in every including translation unit.
///
/// \param m pybind11 module that receives the bindings
inline void bind_tensor_refs_and_views(py::module &m) {

    /// float
    bind_tensor_ref_view<float, cutlass::layout::RowMajor, cutlass::float32>(m, "F32RowMajor");
    bind_tensor_ref_view<float, cutlass::layout::ColumnMajor, cutlass::float32>(m, "F32ColumnMajor");
    bind_tensor_ref_view<float, cutlass::layout::TensorNHWC, cutlass::float32>(m, "F32NHWC");

    /// double
    bind_tensor_ref_view<double, cutlass::layout::RowMajor, cutlass::float64>(m, "F64RowMajor");
    bind_tensor_ref_view<double, cutlass::layout::ColumnMajor, cutlass::float64>(m, "F64ColumnMajor");
    bind_tensor_ref_view<double, cutlass::layout::TensorNHWC, cutlass::float64>(m, "F64NHWC");

    // half_t
    bind_tensor_ref_view<cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t>(m, "F16RowMajor");
    bind_tensor_ref_view<cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t>(m, "F16ColumnMajor");
    bind_tensor_ref_view<cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::half_t>(m, "F16NHWC");

    // bfloat16
    bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t>(m, "BF16RowMajor");
    bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t>(m, "BF16ColumnMajor");
    bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::TensorNHWC, cutlass::bfloat16_t>(m, "BF16NHWC");

    // int8_t
    bind_tensor_ref_view<int8_t, cutlass::layout::RowMajorInterleaved<32>, cutlass::int8>(m, "S8RowMajorInterleaved32");
    bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajorInterleaved<32>, cutlass::int8>(m, "S8ColumnMajorInterleaved32");
    bind_tensor_ref_view<int8_t, cutlass::layout::RowMajor, cutlass::int8>(m, "S8RowMajor");
    bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajor, cutlass::int8>(m, "S8ColumnMajor");
    bind_tensor_ref_view<int8_t, cutlass::layout::TensorNHWC, cutlass::int8>(m, "S8NHWC");
    bind_tensor_ref_view<int8_t, cutlass::layout::TensorNCxHWx<32>, cutlass::int8>(m, "S8NC32HW32");
    bind_tensor_ref_view<int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::int8>(m, "S8C32RSK32");

    // int32_t
    bind_tensor_ref_view<int32_t, cutlass::layout::RowMajor, cutlass::int32>(m, "S32RowMajor");
    bind_tensor_ref_view<int32_t, cutlass::layout::ColumnMajor, cutlass::int32>(m, "S32ColumnMajor");
    bind_tensor_ref_view<int32_t, cutlass::layout::TensorNHWC, cutlass::int32>(m, "S32NHWC");
}
|
||||
146
tools/library/scripts/pycutlass/src/cpp/include/types.h
Normal file
146
tools/library/scripts/pycutlass/src/cpp/include/types.h
Normal file
@ -0,0 +1,146 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind CUTLASS types to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/half.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace cutlass {

/// 8-bit signed integer wrapper.
///
/// Thin value type around `int8_t` so the scalar can be exposed to Python as
/// its own class. Construction uses a plain cast; no range check is done here.
struct alignas(1) int8 {
    int8_t storage;

    /// Stores `x` converted to `int8_t`.
    explicit int8(int x) : storage(static_cast<int8_t>(x)) {}

    /// Stores `x` converted to `int8_t` (fractional part discarded by the cast).
    explicit int8(float x) : storage(static_cast<int8_t>(x)) {}

    /// Returns the stored value as the underlying C type.
    int8_t c_value() const { return storage; }
};

/// 32-bit signed integer wrapper.
struct alignas(4) int32 {
    int storage;

    /// Stores `x` unchanged.
    explicit int32(int x) : storage(x) {}

    /// Stores `x` converted to `int` (fractional part discarded by the cast).
    explicit int32(float x) : storage(static_cast<int>(x)) {}

    /// Returns the stored value as the underlying C type.
    int c_value() const { return storage; }
};

/// IEEE single-precision floating-point wrapper.
struct alignas(4) float32 {
    float storage;

    /// Stores `x` unchanged.
    explicit float32(float x) : storage(x) {}

    /// Stores `x` converted to `float`.
    explicit float32(int x) : storage(static_cast<float>(x)) {}

    /// Returns the stored value as the underlying C type.
    float c_value() const { return storage; }
};

/// IEEE double-precision floating-point wrapper.
///
/// The alignment must not be weaker than that of the `double` member; an
/// `alignas(4)` here is ill-formed on targets where `alignof(double)` is 8.
struct alignas(8) float64 {
    double storage;

    /// Stores `x` unchanged, keeping full double precision.
    explicit float64(double x) : storage(x) {}

    /// Stores `x` widened to `double`.
    explicit float64(float x) : storage(static_cast<double>(x)) {}

    /// Stores `x` converted to `double`.
    explicit float64(int x) : storage(static_cast<double>(x)) {}

    /// Returns the stored value as the underlying C type.
    double c_value() const { return storage; }
};

}  // namespace cutlass
|
||||
|
||||
/// Registers the scalar element types on module `m`.
///
/// Each class exposes its constructors, a read/write `storage` attribute and a
/// `value` method. For the wrapper structs (`int8`, `int32`, `float32`,
/// `float64`) `value` forwards to `c_value()`; for the CUTLASS numeric types
/// (`half_t`, `bfloat16_t`, `tfloat32_t`) it returns the object itself.
///
/// \param m  pybind11 module that receives the class definitions.
void bind_cutlass_types(py::module &m) {

    // s8
    py::class_<cutlass::int8>(m, "int8")
        .def(py::init<float>())
        .def(py::init<int>())
        .def_readwrite("storage", &cutlass::int8::storage)
        .def("value", &cutlass::int8::c_value);

    // s32
    py::class_<cutlass::int32>(m, "int32")
        .def(py::init<float>())
        .def(py::init<int>())
        .def_readwrite("storage", &cutlass::int32::storage)
        .def("value", &cutlass::int32::c_value);

    // f16
    py::class_<cutlass::half_t>(m, "float16")
        .def(py::init<float>())
        .def(py::init<double>())
        .def(py::init<int>())
        .def(py::init<unsigned>())
        .def_readwrite("storage", &cutlass::half_t::storage)
        .def("value", [](const cutlass::half_t& value) {return value;});

    // bf16
    py::class_<cutlass::bfloat16_t>(m, "bfloat16")
        .def(py::init<float>())
        .def(py::init<int>())
        .def_readwrite("storage", &cutlass::bfloat16_t::storage)
        .def("value", [](const cutlass::bfloat16_t& value) {return value;});

    // f32
    py::class_<cutlass::float32>(m, "float32")
        .def(py::init<float>())
        .def(py::init<int>())
        .def_readwrite("storage", &cutlass::float32::storage)
        .def("value", &cutlass::float32::c_value);

    // tf32
    py::class_<cutlass::tfloat32_t>(m, "tfloat32")
        .def(py::init<float>())
        .def(py::init<int>())
        .def_readwrite("storage", &cutlass::tfloat32_t::storage)
        .def("value", [](const cutlass::tfloat32_t& value) {return value;});

    // f64
    py::class_<cutlass::float64>(m, "float64")
        // Overloads are tried in registration order, so the double factory is
        // listed first: a Python float is then stored at full double precision
        // instead of being rounded through the `float` constructor.
        // The value is written to `storage` directly so this does not rely on
        // cutlass::float64 having a constructor that takes a double.
        .def(py::init([](double x) {
            cutlass::float64 result(0);
            result.storage = x;
            return result;
        }))
        .def(py::init<float>())
        .def(py::init<int>())
        .def_readwrite("storage", &cutlass::float64::storage)
        .def("value", &cutlass::float64::c_value);
}
|
||||
32
tools/library/scripts/pycutlass/src/cpp/library.h
Normal file
32
tools/library/scripts/pycutlass/src/cpp/library.h
Normal file
@ -0,0 +1,32 @@
|
||||
// Include guard: this header defines enums, so a second inclusion in the same
// translation unit would otherwise be a redefinition error.
#pragma once

#include <cutlass/complex.h>

namespace cutlass {

/// Enumeration of element data types.
///
/// Enumerators take implicit consecutive values starting at 0, so their order
/// is significant: append new entries rather than reordering existing ones.
enum class DataType {
    kB1, kU2, kU4, kU8,
    kU16, kU32, kU64, kS2,
    kS4, kS8, kS16, kS32,
    kS64, kF16, kBF16, kF32,
    kTF32, kF64, kCF16, kCBF16,
    kCF32, kCTF32, kCF64, kCS2,
    kCS4, kCS8, kCS16, kCS32,
    kCS64, kCU2, kCU4, kCU8,
    kCU16, kCU32, kCU64, kInvalid
};

/// Enumeration of tensor/matrix layout types.
///
/// Enumerators take implicit consecutive values starting at 0, so their order
/// is significant: append new entries rather than reordering existing ones.
enum class LayoutType {
    kColumnMajor, kRowMajor,
    kColumnMajorInterleaved2, kRowMajorInterleaved2,
    kColumnMajorInterleaved32, kRowMajorInterleaved32,
    kColumnMajorInterleaved64, kRowMajorInterleaved64,
    kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC,
    kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32,
    kTensorC64RSK64
};

}  // namespace cutlass
|
||||
@ -0,0 +1,54 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind convolution problems to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
|
||||
#include "unit/conv/device/conv2d_problems.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<cutlass::conv::Conv2dProblemSize>);
|
||||
|
||||
void bind_conv_problem_size_test(py::module &m) {
|
||||
|
||||
py::bind_vector<std::vector<cutlass::conv::Conv2dProblemSize>>(m, "Conv2dProblemVector")
|
||||
.def("size", &std::vector<cutlass::conv::Conv2dProblemSize>::size);
|
||||
// Get Conv2d problem sizes
|
||||
py::class_<test::conv::device::TestbedConv2dProblemSizes>(m, "TestbedConv2dProblemSizes")
|
||||
.def(py::init<int>())
|
||||
.def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes);
|
||||
}
|
||||
@ -0,0 +1,49 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind convolution related types to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "conv_problems.h"
|
||||
#include "host.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Registers all convolution test helpers on module `m`.
///
/// Problem-size types go directly on `m`; the host reference functions are
/// placed in a `host` submodule created here.
///
/// \param m  pybind11 module that receives the definitions.
void bind_convolution_test(py::module &m) {
    // Conv problem sizes
    bind_conv_problem_size_test(m);

    // Host-side reference implementations live under `<m>.host`.
    py::module_ host = m.def_submodule("host");
    bind_conv_host_references(host);
}
|
||||
180
tools/library/scripts/pycutlass/src/cpp/test/conv/host.h
Normal file
180
tools/library/scripts/pycutlass/src/cpp/test/conv/host.h
Normal file
@ -0,0 +1,180 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind Convolution host test helpers to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include "unit/conv/device/cache_testbed_output.h"
|
||||
|
||||
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
|
||||
void bind_conv2d_host(py::module &m) {
|
||||
m.def("conv2d", \
|
||||
&cutlass::reference::host::Conv2d< \
|
||||
Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);
|
||||
|
||||
m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
|
||||
}
|
||||
|
||||
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
|
||||
void bind_conv2d_host_sat(py::module &m) {
|
||||
m.def("conv2d", \
|
||||
&cutlass::reference::host::Conv2d< \
|
||||
Ta, La, Tb, Lb, Tc, Lc, Te, Tacc, cutlass::NumericConverterClamp<Tc, Te>>);
|
||||
|
||||
m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
|
||||
}
|
||||
|
||||
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
|
||||
void bind_conv2d_host_nhwc(py::module &m) {
|
||||
bind_conv2d_host<
|
||||
Ta, cutlass::layout::TensorNHWC,
|
||||
Tb, cutlass::layout::TensorNHWC,
|
||||
Tc, cutlass::layout::TensorNHWC,
|
||||
Tacc, Te>(m);
|
||||
}
|
||||
|
||||
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
|
||||
void bind_conv2d_host_nc32hw32(py::module &m) {
|
||||
bind_conv2d_host_sat<
|
||||
Ta, cutlass::layout::TensorNCxHWx<32>,
|
||||
Tb, cutlass::layout::TensorCxRSKx<32>,
|
||||
Tc, cutlass::layout::TensorNCxHWx<32>,
|
||||
Tacc, Te>(m);
|
||||
}
|
||||
|
||||
|
||||
template<typename T, typename Layout>
|
||||
void bind_tensor_equals(py::module &m) {
|
||||
m.def("equals", py::overload_cast<
|
||||
const cutlass::TensorView<T, Layout>&, const cutlass::TensorView<T, Layout>&>(
|
||||
&cutlass::reference::host::TensorEquals<T, Layout>
|
||||
));
|
||||
}
|
||||
|
||||
#define BIND_TENSOR_HASH(Element, Layout) { \
|
||||
m.def("TensorHash", &test::conv::device::TensorHash<Element, Layout>, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \
|
||||
}
|
||||
|
||||
void bind_conv_host_references(py::module &m) {
|
||||
//
|
||||
// Conv2d reference on host
|
||||
// tools/util/include/cutlass/util/reference/host/convolution.h
|
||||
|
||||
/// double
|
||||
bind_conv2d_host_nhwc<double, double, double, double, double>(m);
|
||||
/// float
|
||||
bind_conv2d_host_nhwc<float, float, float, float, float>(m);
|
||||
/// half
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, cutlass::half_t>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, float>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, cutlass::half_t>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, float>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, float>(m);
|
||||
/// bfloat16
|
||||
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, cutlass::bfloat16_t>(m);
|
||||
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);
|
||||
/// s8
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
|
||||
//
|
||||
// Compare whether two tensors are equal
|
||||
//
|
||||
/// double
|
||||
bind_tensor_equals<double, cutlass::layout::TensorNHWC>(m);
|
||||
/// float
|
||||
bind_tensor_equals<float, cutlass::layout::TensorNHWC>(m);
|
||||
/// half
|
||||
bind_tensor_equals<cutlass::half_t, cutlass::layout::TensorNHWC>(m);
|
||||
/// bfloat16
|
||||
bind_tensor_equals<cutlass::bfloat16_t, cutlass::layout::TensorNHWC>(m);
|
||||
/// s32
|
||||
bind_tensor_equals<int32_t, cutlass::layout::TensorNHWC>(m);
|
||||
bind_tensor_equals<int32_t, cutlass::layout::TensorNCxHWx<32>>(m);
|
||||
/// s8
|
||||
bind_tensor_equals<int8_t, cutlass::layout::TensorNHWC>(m);
|
||||
bind_tensor_equals<int8_t, cutlass::layout::TensorNCxHWx<32>>(m);
|
||||
|
||||
/// Cache
|
||||
py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
|
||||
.def(py::init<>())
|
||||
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>());
|
||||
|
||||
py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
|
||||
.def(py::init<>())
|
||||
.def(py::init<uint32_t>())
|
||||
.def_readwrite("D", &test::conv::device::CachedTestResult::D);
|
||||
|
||||
py::class_<test::conv::device::CachedTestResultListing>(m, "CachedTestResultListing")
|
||||
.def(py::init<const std::string &>())
|
||||
.def("find", &test::conv::device::CachedTestResultListing::find)
|
||||
.def("append", &test::conv::device::CachedTestResultListing::append)
|
||||
.def("write", &test::conv::device::CachedTestResultListing::write);
|
||||
|
||||
py::class_<test::conv::device::CRC32>(m, "CRC32")
|
||||
.def(py::init<>());
|
||||
|
||||
BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC)
|
||||
BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC);
|
||||
BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC);
|
||||
BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC);
|
||||
BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC);
|
||||
BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>);
|
||||
}
|
||||
45
tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h
Normal file
45
tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h
Normal file
@ -0,0 +1,45 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind gemm test to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "host.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/// Registers the GEMM test helpers: creates a `host` submodule of `m` and
/// binds the host reference functions into it.
///
/// \param m  pybind11 module that receives the submodule.
void bind_gemm_test(py::module &m) {
    py::module_ host = m.def_submodule("host");
    bind_gemm_host_reference(host);
}
|
||||
431
tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h
Normal file
431
tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h
Normal file
@ -0,0 +1,431 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind gemm test host functions to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
template<
|
||||
typename ElementA, typename LayoutA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename AccumulatorType, typename ComputeType,
|
||||
typename InnerProductOp>
|
||||
void bind_host_gemm_saturate(py::module &m) {
|
||||
m.def("gemm_saturate", py::overload_cast<
|
||||
cutlass::gemm::GemmCoord, ComputeType,
|
||||
cutlass::TensorRef<ElementA, LayoutA>,
|
||||
cutlass::TensorRef<ElementB, LayoutB>,
|
||||
ComputeType,
|
||||
cutlass::TensorRef<ElementC, LayoutC>,
|
||||
cutlass::TensorRef<ElementC, LayoutC>,
|
||||
AccumulatorType>(
|
||||
&cutlass::reference::host::compute_gemm<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ComputeType,
|
||||
AccumulatorType,
|
||||
InnerProductOp,
|
||||
cutlass::NumericConverterClamp<ElementC, AccumulatorType>>
|
||||
));
|
||||
}
|
||||
|
||||
template<
|
||||
typename ElementA, typename LayoutA,
|
||||
typename ElementB, typename LayoutB,
|
||||
typename ElementC, typename LayoutC,
|
||||
typename AccumulatorType, typename ComputeType,
|
||||
typename InnerProductOp>
|
||||
void bind_host_gemm(py::module &m) {
|
||||
m.def("gemm", py::overload_cast<
|
||||
cutlass::gemm::GemmCoord, ComputeType,
|
||||
cutlass::TensorRef<ElementA, LayoutA>,
|
||||
cutlass::TensorRef<ElementB, LayoutB>,
|
||||
ComputeType,
|
||||
cutlass::TensorRef<ElementC, LayoutC>,
|
||||
cutlass::TensorRef<ElementC, LayoutC>,
|
||||
AccumulatorType>(
|
||||
&cutlass::reference::host::compute_gemm<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ComputeType,
|
||||
AccumulatorType,
|
||||
InnerProductOp,
|
||||
cutlass::NumericConverter<ElementC, AccumulatorType>>
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
template<
|
||||
typename ElementA, typename ElementB, typename ElementC,
|
||||
typename AccumulatorType, typename ComputeType>
|
||||
void bind_host_gemm_multiply_add(py::module &m) {
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
ComputeType, AccumulatorType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
}
|
||||
|
||||
template<
|
||||
typename ElementA, typename ElementB, typename ElementC,
|
||||
typename AccumulatorType, typename ComputeType>
|
||||
void bind_host_gemm_multiply_add_saturate(py::module &m) {
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
ComputeType, AccumulatorType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::RowMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::RowMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajor,
|
||||
ElementB, cutlass::layout::ColumnMajor,
|
||||
ElementC, cutlass::layout::ColumnMajor,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
}
|
||||
|
||||
|
||||
template<
|
||||
typename ElementA, typename ElementB, typename ElementC,
|
||||
typename AccumulatorType, typename ComputeType>
|
||||
void bind_host_gemm_multiply_add_interleaved(py::module &m) {
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ComputeType, AccumulatorType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
}
|
||||
|
||||
template<
|
||||
typename ElementA, typename ElementB, typename ElementC,
|
||||
typename AccumulatorType, typename ComputeType>
|
||||
void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) {
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ComputeType, AccumulatorType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::RowMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
|
||||
bind_host_gemm_saturate<
|
||||
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
AccumulatorType, ComputeType,
|
||||
cutlass::multiply_add<AccumulatorType>>(m);
|
||||
}
|
||||
|
||||
/// Defines an "equals" overload on the module variable `m`, which must be in
/// scope at the point of use. The overload wraps
/// cutlass::reference::host::TensorEquals<Element, Layout>, selected through
/// py::overload_cast for two `const cutlass::TensorView<Element, Layout>&`
/// arguments.
///
/// The body is wrapped in do { ... } while (0) so that the expansion plus the
/// caller's trailing semicolon forms exactly one statement; a bare { ... }
/// block would leave a stray empty statement and break inside an un-braced
/// if/else. Always invoke it with a trailing semicolon.
#define BIND_TENSOR_EQUAL(Element, Layout) do {                                                    \
    m.def("equals", py::overload_cast<                                                             \
        const cutlass::TensorView<Element, Layout>&, const cutlass::TensorView<Element, Layout>&>( \
            &cutlass::reference::host::TensorEquals<Element, Layout>));                            \
} while (0)
|
||||
|
||||
/// Invokes the host GEMM binder templates and the BIND_TENSOR_EQUAL macro on
/// `m` for every element-type combination listed below.
///
/// The five template arguments of each binder call are, in order,
/// <ElementA, ElementB, ElementC, AccumulatorType, ComputeType>.
///
/// Each int8 group used to list the two `..., int32_t, float` combinations
/// twice in a row. The repeated calls were identical instantiations, so the
/// second registration could never be selected ahead of the first; the
/// repeats are dropped and the order of the remaining calls is unchanged.
///
/// \param m  pybind11 module passed to every binder and used by BIND_TENSOR_EQUAL.
void bind_gemm_host_reference(py::module &m) {
    using half_t = cutlass::half_t;
    using bfloat16_t = cutlass::bfloat16_t;

    /// double
    bind_host_gemm_multiply_add<double, double, double, double, double>(m);

    /// float
    bind_host_gemm_multiply_add<float, float, float, float, float>(m);

    /// half_t
    bind_host_gemm_multiply_add<half_t, half_t, half_t, half_t, half_t>(m);
    bind_host_gemm_multiply_add<half_t, half_t, half_t, float, float>(m);
    bind_host_gemm_multiply_add<half_t, half_t, float, half_t, half_t>(m);
    bind_host_gemm_multiply_add<half_t, half_t, float, float, float>(m);

    /// bfloat16
    bind_host_gemm_multiply_add<bfloat16_t, bfloat16_t, bfloat16_t, float, float>(m);
    bind_host_gemm_multiply_add<bfloat16_t, bfloat16_t, float, float, float>(m);

    /// s8
    bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, float>(m);
    bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, float>(m);

    /// s8, interleaved layouts
    bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
    bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);

    /// s8, saturate variant
    bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
    bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);

    /// s8, saturate variant with interleaved layouts
    bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
    bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
    bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
    bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);

    // float
    BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor);
    BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor);

    // double
    BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor);
    BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor);

    // half_t
    BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor);
    BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor);

    // bfloat16
    BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor);
    BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor);

    // int32_t
    BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor);
    BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor);

    // int8_t
    BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor);
    BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor);
    BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>);
    BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>);
}
|
||||
Reference in New Issue
Block a user