Files
cutlass/python/cutlass/cpp/include/swizzling.h
ANIKET SHIVAM 4575443d44 CUTLASS 3.2 (#1024)
* CUTLASS 3.2
2023-08-07 20:50:32 -04:00

170 lines
8.7 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind threadblock swizzling to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include <cxxabi.h>
#include <cuda_runtime.h>
namespace py = pybind11;
std::string demangle(const char* mangled_name) {
std::size_t len = 0;
int status = 0;
std::unique_ptr<char> ptr(
__cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status));
return ptr.get();
}
template<typename T>
void bind_identity_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv3dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_swizzle(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_swizzle_streamk(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_dgrad_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
void bind_threadblock_swizzle(py::module &m) {
py::class_<dim3>(m, "dim3",
R"pbdoc(A int3 type xyz contains three integers)pbdoc")
.def(py::init<int, int, int>(),
py::arg("x"), py::arg("y"), py::arg("z"))
.def_readwrite("x", &dim3::x, R"pbdoc(get value x)pbdoc")
.def_readwrite("y", &dim3::y, R"pbdoc(get value y)pbdoc")
.def_readwrite("z", &dim3::z, R"pbdoc(get value z)pbdoc");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>>(m, "IdentitySwizzle1");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>>(m, "IdentitySwizzle2");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>>(m, "IdentitySwizzle4");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>>(m, "IdentitySwizzle8");
bind_swizzle<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle>(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc");
bind_swizzle<cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc");
bind_swizzle_streamk<cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>(m, "ThreadblockSwizzleStreamK", R"pbdoc(Threadblock swizzling function using Stream K feature)pbdoc");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>>(m, "StridedDgradIdentitySwizzle1");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>>(m, "StridedDgradIdentitySwizzle4");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle>(m, "StridedDgradHorizontalSwizzle");
}