Rename python/cutlass to python/cutlass_cppgen (#2652)
This commit is contained in:
committed by
Haicheng Wu
parent
4260d4aef9
commit
177a82e251
33
python/cutlass_cppgen/emit/__init__.py
Normal file
33
python/cutlass_cppgen/emit/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
267
python/cutlass_cppgen/emit/common.py
Normal file
267
python/cutlass_cppgen/emit/common.py
Normal file
@ -0,0 +1,267 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Common utilities for emitting CUTLASS kernels
|
||||
"""
|
||||
|
||||
import cutlass_cppgen
|
||||
|
||||
# Strings used for printing information about the generation of emitted scripts
|
||||
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
|
||||
|
||||
|
||||
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
|
||||
"""
|
||||
|
||||
|
||||
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_ARGS_2x = """
|
||||
typename DeviceKernel::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K}, // problem size
|
||||
1,
|
||||
{alpha, beta},
|
||||
A, B, C, D,
|
||||
0, 0, 0, 0, // batch strides
|
||||
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
||||
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
|
||||
};
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
|
||||
typename DeviceKernel::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K}, // problem size
|
||||
1,
|
||||
{alpha, beta},
|
||||
A, B, C, D,
|
||||
0, 0, 0, 0, // batch strides
|
||||
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
||||
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
|
||||
-1 // avail_sms
|
||||
};
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_RUN_GEMM_2x = """
|
||||
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
cutlass::Status ${name}_kernel_run(int M, int N, int K,
|
||||
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
||||
ElementCompute alpha, ElementCompute beta) {
|
||||
${args}
|
||||
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
DeviceKernel gemm_op;
|
||||
cutlass::Status status = gemm_op.initialize(arguments,
|
||||
workspace.get(),
|
||||
nullptr); // CUDA stream
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = gemm_op();
|
||||
return status;
|
||||
}
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_RUN_GEMM_3x = """
|
||||
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
|
||||
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
|
||||
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
|
||||
using StrideD = typename DeviceKernel::GemmKernel::StrideD;
|
||||
|
||||
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
cutlass::Status ${name}_kernel_run(
|
||||
int M, int N, int K, int L,
|
||||
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
||||
ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
typename DeviceKernel::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K, L}, // problem size
|
||||
{
|
||||
A, // ptrA
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
},
|
||||
{
|
||||
{alpha, beta},
|
||||
C, // ptrC
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
||||
D, // ptrD
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
||||
},
|
||||
hw_info
|
||||
};
|
||||
|
||||
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
DeviceKernel gemm_op;
|
||||
cutlass::Status status = gemm_op.run(arguments,
|
||||
workspace.get(),
|
||||
nullptr); // CUDA stream
|
||||
|
||||
return status;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
|
||||
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
int threadblock_count = DeviceKernel::sufficient();
|
||||
|
||||
cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
|
||||
DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
|
||||
int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
|
||||
ElementCompute alpha, ElementCompute beta) {
|
||||
|
||||
typename DeviceKernel::Arguments arguments {
|
||||
problem_sizes,
|
||||
problem_count,
|
||||
threadblock_count,
|
||||
{alpha, beta},
|
||||
A, B, C, D,
|
||||
lda, ldb, ldc, ldd
|
||||
};
|
||||
|
||||
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
DeviceKernel gemm_op;
|
||||
cutlass::Status status = gemm_op.initialize(arguments,
|
||||
workspace.get(),
|
||||
nullptr); // CUDA stream
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = gemm_op();
|
||||
return status;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
_CUTLASS_KERNEL_RUN_CONV2D_2x = """
|
||||
|
||||
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
|
||||
namespace {
|
||||
using TensorRefA = typename UnderlyingKernel::TensorRefA;
|
||||
using TensorRefB = typename UnderlyingKernel::TensorRefB;
|
||||
using TensorRefC = typename UnderlyingKernel::TensorRefC;
|
||||
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
|
||||
}
|
||||
|
||||
template<typename TensorRef, typename Element>
|
||||
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
|
||||
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
|
||||
TensorRef tensor_ref(ptr, layout);
|
||||
return tensor_ref;
|
||||
}
|
||||
|
||||
cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
|
||||
UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
|
||||
UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
|
||||
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
|
||||
cudaStream_t stream, int device_id=0) {
|
||||
// create the tensor references
|
||||
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
|
||||
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
||||
);
|
||||
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
|
||||
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
||||
);
|
||||
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
|
||||
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
||||
);
|
||||
|
||||
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
|
||||
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
|
||||
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
|
||||
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
|
||||
|
||||
cutlass::conv::SplitKMode mode;
|
||||
if (split_k_mode == "serial") {
|
||||
mode = cutlass::conv::SplitKMode::kSerial;
|
||||
} else if (split_k_mode == "parallel") {
|
||||
mode = cutlass::conv::SplitKMode::kParallel;
|
||||
} else {
|
||||
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
|
||||
}
|
||||
|
||||
typename DeviceKernel::Arguments arguments{
|
||||
*problem_size,
|
||||
tensor_ref_A,
|
||||
tensor_ref_B,
|
||||
tensor_ref_C,
|
||||
tensor_ref_D,
|
||||
{alpha, beta},
|
||||
mode
|
||||
};
|
||||
|
||||
DeviceKernel implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
|
||||
|
||||
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
status = implicit_gemm_op(stream);
|
||||
|
||||
return status;
|
||||
}
|
||||
"""
|
||||
936
python/cutlass_cppgen/emit/pytorch.py
Normal file
936
python/cutlass_cppgen/emit/pytorch.py
Normal file
@ -0,0 +1,936 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel.
|
||||
If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.
|
||||
|
||||
Example usage with JIT compilation:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
|
||||
op = plan.construct()
|
||||
mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
|
||||
|
||||
# Generate inputs for the GEMM
|
||||
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
||||
|
||||
# Run the module
|
||||
D = mod.run(A, B, C)
|
||||
|
||||
|
||||
Example usage without JIT compilation:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
op = plan.construct()
|
||||
cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
|
||||
|
||||
After this call, the directory ``output`` contains ``setup.py``,
|
||||
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
|
||||
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.
|
||||
|
||||
The module can later be used in Python via:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import cutlass_gemm
|
||||
|
||||
# Generate inputs for the GEMM
|
||||
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
||||
|
||||
# Run the module
|
||||
D = cutlass_gemm.run(A, B, C)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
|
||||
|
||||
from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
|
||||
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
|
||||
from cutlass_cppgen.backend.library import ApiVersion
|
||||
from cutlass_cppgen.emit import common
|
||||
from cutlass_cppgen.utils.datatypes import is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
// helper function allocating the memory
|
||||
void* device_memory_allocation(size_t size, int device_id=0) {
|
||||
if (size > 0) {
|
||||
torch::Device device(torch::kCUDA, device_id);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
|
||||
at::Tensor device_tensor = torch::empty({(long)size,}, options);
|
||||
return reinterpret_cast<void*>(device_tensor.data_ptr());
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
${includes}
|
||||
${declaration}
|
||||
${impl}
|
||||
"""
|
||||
|
||||
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);
|
||||
|
||||
// C++ interface
|
||||
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
|
||||
return ${name}_kernel(A, B, C, alpha, beta);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);
|
||||
|
||||
// C++ interface
|
||||
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
|
||||
return ${name}_kernel(A, B, C, alpha, beta);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
|
||||
py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
at::Tensor ${name}_kernel(
|
||||
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1);
|
||||
|
||||
// C++ interface
|
||||
at::Tensor ${name}(
|
||||
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("run",
|
||||
py::overload_cast<
|
||||
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
||||
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
||||
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
||||
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
||||
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
||||
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
at::Tensor ${name}_kernel(
|
||||
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1);
|
||||
|
||||
// C++ interface
|
||||
at::Tensor ${name}(
|
||||
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("run",
|
||||
py::overload_cast<
|
||||
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
||||
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
||||
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
||||
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
||||
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
||||
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_GEMM_INCLUDES = {
|
||||
ApiVersion.v2x: """
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
""",
|
||||
ApiVersion.v3x: """
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
""",
|
||||
}
|
||||
|
||||
_PYTORCH_GROUPED_GEMM_INCLUDES = """
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_INCLUDES = """
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
"""
|
||||
|
||||
_CUTLASS_TYPE_TO_TORCH_TYPE = {
|
||||
DataType.f16: "torch::kF16",
|
||||
DataType.f32: "torch::kF32",
|
||||
DataType.f64: "torch::kF64",
|
||||
DataType.s8: "torch::kI8",
|
||||
DataType.s32: "torch::kI32",
|
||||
DataType.bf16: "torch::kBFloat16",
|
||||
}
|
||||
|
||||
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_GEMM_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
|
||||
int M = A.size(0);
|
||||
int N = B.size(1);
|
||||
int K = A.size(1);
|
||||
|
||||
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
|
||||
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
|
||||
|
||||
cutlass::Status status = ${name}_kernel_run(M, N, K,
|
||||
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
|
||||
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
|
||||
ptrC,
|
||||
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
|
||||
ElementCompute(alpha), ElementCompute(beta));
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
||||
return D;
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
|
||||
common._CUTLASS_KERNEL_RUN_GEMM_3x
|
||||
+ """
|
||||
bool hw_info_queried = false;
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
|
||||
int M = A.size(0);
|
||||
int N = B.size(1);
|
||||
int K = A.size(1);
|
||||
int L = 1;
|
||||
|
||||
// Query hardware info if we haven't already
|
||||
if (!hw_info_queried) {
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
|
||||
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
|
||||
|
||||
cutlass::Status status = ${name}_kernel_run(M, N, K, L,
|
||||
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
|
||||
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
|
||||
ptrC,
|
||||
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
|
||||
ElementCompute(alpha), ElementCompute(beta),
|
||||
hw_info);
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
||||
return D;
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
|
||||
common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
|
||||
+ """
|
||||
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
|
||||
size_t num = A.size();
|
||||
|
||||
// To avoid performing many small cudaMallocs and host-to-device copies,
|
||||
// we serialize the grouped GEMM arguments on the host, allocate one
|
||||
// large chunk of device memory, and perform a single cudaMemcpy to
|
||||
// copy the host data to the device. Allocation overheads could be
|
||||
// avoided by using a memory pool.
|
||||
|
||||
// Calculate the total size of the data to be copied from host to device
|
||||
size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
|
||||
sizeof(DeviceKernel::ElementA*) +
|
||||
sizeof(DeviceKernel::ElementB*) +
|
||||
sizeof(DeviceKernel::ElementC*) +
|
||||
sizeof(DeviceKernel::ElementC*) +
|
||||
sizeof(int64_t) +
|
||||
sizeof(int64_t) +
|
||||
sizeof(int64_t);
|
||||
total_size *= num;
|
||||
|
||||
// num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
|
||||
// of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
|
||||
// To ensure that we don't end up having misaligned loads in the kernel,
|
||||
// we pad to the nearest multiple of 8.
|
||||
//
|
||||
// Note that, even on a 32-bit system (for which sizeof(X*) will not equal
|
||||
// sizeof(int64_t)), only padding between the list of GemmCoords and the
|
||||
// list of ptr_As is sufficient because the set of four equal-length lists of pointers
|
||||
// (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
|
||||
// start on a multiple of 8.
|
||||
int64_t padding = 8 - (total_size % 8);
|
||||
total_size += padding;
|
||||
|
||||
uint8_t* host_data = new uint8_t[total_size];
|
||||
cutlass::DeviceAllocation<uint8_t> device_data(total_size);
|
||||
|
||||
uint8_t* start = host_data;
|
||||
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);
|
||||
|
||||
// Apply the padding after the list of GemmCoords
|
||||
start += num * sizeof(cutlass::gemm::GemmCoord) + padding;
|
||||
|
||||
int64_t ptr_A_offset = start - host_data;
|
||||
DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
|
||||
start += num * sizeof(DeviceKernel::ElementA*);
|
||||
|
||||
int64_t ptr_B_offset = start - host_data;
|
||||
DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
|
||||
start += num * sizeof(DeviceKernel::ElementB*);
|
||||
|
||||
int64_t ptr_C_offset = start - host_data;
|
||||
DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
|
||||
start += num * sizeof(DeviceKernel::ElementC*);
|
||||
|
||||
int64_t ptr_D_offset = start - host_data;
|
||||
DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
|
||||
start += num * sizeof(DeviceKernel::ElementC*);
|
||||
|
||||
int64_t lda_offset = start - host_data;
|
||||
int64_t* lda_host = reinterpret_cast<int64_t*>(start);
|
||||
start += num * sizeof(int64_t);
|
||||
|
||||
int64_t ldb_offset = start - host_data;
|
||||
int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
|
||||
start += num * sizeof(int64_t);
|
||||
|
||||
int64_t ldc_offset = start - host_data;
|
||||
int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
|
||||
start += num * sizeof(int64_t);
|
||||
|
||||
std::vector<at::Tensor> D(num);
|
||||
|
||||
bool need_C = (C != at::nullopt) && (beta != 0.f);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
int M = A[i].size(0);
|
||||
int N = B[i].size(1);
|
||||
int K = A[i].size(1);
|
||||
*(problem_sizes_host + i) = {M, N, K};
|
||||
*(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
|
||||
*(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());
|
||||
|
||||
if (need_C) {
|
||||
*(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
|
||||
}
|
||||
else {
|
||||
*(ptr_C_host + i) = nullptr;
|
||||
}
|
||||
|
||||
D[i] = B[i].new_empty({M, N}, ${torch_type_C});
|
||||
*(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());
|
||||
|
||||
*(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
|
||||
*(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
|
||||
*(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
|
||||
}
|
||||
|
||||
device_data.copy_from_host(host_data);
|
||||
|
||||
cutlass::Status status = ${name}_kernel_run(
|
||||
num,
|
||||
reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
|
||||
reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
|
||||
reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
|
||||
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
|
||||
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
|
||||
reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
|
||||
reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
|
||||
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
|
||||
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
|
||||
ElementCompute(alpha), ElementCompute(beta));
|
||||
|
||||
delete[] host_data;
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
||||
return D;
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
cutlass::Status status = ${name}_kernel_run(
|
||||
&problem_size,
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
|
||||
ptrC,
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
|
||||
alpha, beta,
|
||||
split_k_mode, stream, B.device().index());
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
||||
return D;
|
||||
}
|
||||
"""
|
||||
|
||||
# C++ entry point for a 2.x conv2d forward-propagation (fprop) kernel.
# Reads activation shape (N, C, H, W) from A and filter shape (K, R, S) from B,
# derives the output extents P/Q from the constructed Conv2dProblemSize, and
# allocates the channels-last output D = zeros(N, K, P, Q).
# NOTE(review): fprop zero-initializes D while the dgrad/wgrad templates use
# torch::empty — presumably intentional (e.g. for split-K accumulation); confirm.
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
                          std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
                          float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
  int N, H, W, C_, K, R, S, P, Q;
  N = A.size(0);
  C_ = A.size(1);
  H = A.size(2);
  W = A.size(3);

  K = B.size(0);
  R = B.size(2);
  S = B.size(3);

  cutlass::conv::Conv2dProblemSize problem_size(
    cutlass::Tensor4DCoord(N, H, W, C_),
    cutlass::Tensor4DCoord(K, R, S, C_),
    cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
    cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
    cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
    cutlass::conv::Mode::kCrossCorrelation,
    split_k_slices
  );

  P = problem_size.P;
  Q = problem_size.Q;

  typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                              nullptr :
                                              reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

  torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
  at::Tensor D = torch::zeros({N, K, P, Q}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
|
||||
|
||||
|
||||
# C++ entry point for a 2.x conv2d data-gradient (dgrad) kernel.
# The activation size (N, C, H, W) is passed explicitly as ``input_size``
# because multiple H/W values can map to the same output P/Q; the filter shape
# (K, R, S) is read from B. Output D = empty(N, C, H, W) in channels-last.
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
                          std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
                          std::string split_k_mode="serial", int split_k_slices=1) {
  int N, H, W, C_, K, R, S;
  N = std::get<0>(input_size);
  C_ = std::get<1>(input_size);
  H = std::get<2>(input_size);
  W = std::get<3>(input_size);

  K = B.size(0);
  R = B.size(2);
  S = B.size(3);

  cutlass::conv::Conv2dProblemSize problem_size(
    cutlass::Tensor4DCoord(N, H, W, C_),
    cutlass::Tensor4DCoord(K, R, S, C_),
    cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
    cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
    cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
    cutlass::conv::Mode::kCrossCorrelation,
    split_k_slices
  );

  typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                              nullptr :
                                              reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

  torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
  at::Tensor D = torch::empty({N, C_, H, W}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
|
||||
|
||||
|
||||
# C++ entry point for a 2.x conv2d weight-gradient (wgrad) kernel.
# The filter size (K, C, R, S) is passed explicitly as ``weight_size`` because
# R/S cannot be uniquely recovered from the output extents; the activation
# extents (N, H, W) are read from B. Output D = empty(K, C, R, S),
# channels-last.
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
                          std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
                          std::string split_k_mode="serial", int split_k_slices=1) {
  int N, H, W, C_, K, R, S;
  K = std::get<0>(weight_size);
  C_ = std::get<1>(weight_size);
  R = std::get<2>(weight_size);
  S = std::get<3>(weight_size);

  N = B.size(0);
  H = B.size(2);
  W = B.size(3);

  cutlass::conv::Conv2dProblemSize problem_size(
    cutlass::Tensor4DCoord(N, H, W, C_),
    cutlass::Tensor4DCoord(K, R, S, C_),
    cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
    cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
    cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
    cutlass::conv::Mode::kCrossCorrelation,
    split_k_slices
  );

  typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                              nullptr :
                                              reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

  torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
  at::Tensor D = torch::empty({K, C_, R, S}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
|
||||
|
||||
|
||||
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='${name}',
|
||||
ext_modules=[
|
||||
CUDAExtension('${name}', [
|
||||
'${name}.cpp',
|
||||
'${name}_kernel.cu',
|
||||
],
|
||||
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
|
||||
extra_compile_args={
|
||||
'cxx': ['-std=c++17'],
|
||||
'nvcc': ['-std=c++17', ${extra_compile_args}],
|
||||
},
|
||||
libraries=['cuda']
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
    """
    Write a ``setup.py`` build script for the generated extension into ``sourcedir``.

    :param name: name of the module to generate
    :type name: str
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str
    :param extra_compile_args: additional arguments to pass to setup.py
    :type extra_compile_args: str
    """
    # Fill in the module name, CUTLASS include root, and any extra nvcc flags.
    substitutions = {
        "name": name,
        "cutlass_path": CUTLASS_PATH,
        "extra_compile_args": extra_compile_args,
    }
    rendered = SubstituteTemplate(_PYTORCH_SETUP_PY, substitutions)
    with open(os.path.join(sourcedir, "setup.py"), "w") as outfile:
        outfile.write(rendered)
|
||||
|
||||
|
||||
class _ArchListSetter:
    """
    Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
    environment variable when building a PyTorch CUDA module.

    ``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch
    CUDA module should be compiled.

    For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
    ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
    compilation of the module.

    This utility wraps the building of a PyTorch CUDA module with a setting of this environment
    variable according to the current compute capability being targeted.

    Example usage:

    .. highlight:: python
    .. code-block:: python

        # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
        with _ArchListSetter(80):
            # Perform JIT compilation and loading of the module
            mod = torch.utils.cpp_extension.load(...)

    :param cc: compute capability
    :type cc: int
    """

    # Name of the environment variable manipulated by this context manager
    _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"

    def __init__(self, cc: int):
        # Format the integer capability as "major.minor". The previous
        # ``".".join(list(str(cc)))`` formatting is only correct for two-digit
        # capabilities: for the three-digit values _jit supports (100/101/103)
        # it produced e.g. "1.0.0" instead of "10.0", which PyTorch's arch-list
        # parser rejects.
        self.cc_str = f"{cc // 10}.{cc % 10}"

    def __enter__(self):
        """
        Saves the old value of TORCH_CUDA_ARCH_LIST and resets it to the new value based on ``cc``
        """
        self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
        os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str

        return self

    def __exit__(self, exc_type, exc_val, traceback):
        """
        Restores the old value of TORCH_CUDA_ARCH_LIST
        """
        if self.old_arch_list is None:
            # The variable was unset on entry; remove our temporary value.
            del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
        else:
            os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
|
||||
|
||||
|
||||
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
    """
    JIT compiles and loads a PyTorch CUDA extension.

    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param cpp_file: path to file containing extension's C++ interface
    :type cpp_file: str
    :param cuda_file: path to file containing extension's CUDA interface
    :type cuda_file: str

    :return: loaded PyTorch module
    """

    from torch.utils.cpp_extension import load

    nvcc_flags = ["-std=c++17"]
    if cc in [90, 100, 101, 103]:
        # PyTorch does not currently add the sm_90a target when compute capability
        # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.
        nvcc_flags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")

    include_paths = [
        os.path.join(CUTLASS_PATH, "include"),
        os.path.join(CUTLASS_PATH, "tools/util/include"),
    ]

    # Build with TORCH_CUDA_ARCH_LIST pinned to the requested capability.
    with _ArchListSetter(cc):
        return load(
            name,
            [cpp_file, cuda_file],
            extra_cuda_cflags=nvcc_flags,
            extra_include_paths=include_paths,
            extra_ldflags=["-lcuda"],
            verbose=(logger.level == logging.DEBUG)
        )
|
||||
|
||||
|
||||
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    if op.api == ApiVersion.v3x:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
    else:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
        # 2.x kernels additionally need an argument-construction snippet,
        # which differs when stream-K threadblock swizzling is in use.
        if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
        else:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
    # NOTE: a redundant second assignment of ``impl_template`` (a conditional
    # expression selecting the same template as the branch above) used to sit
    # here; it was dead code that only obscured the stream-K handling and has
    # been removed.
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GEMM_INCLUDES[op.api],
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        _PYTORCH_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    extra_compile_args = ""
    if cc in [90, 100, 101, 103]:
        # These capabilities require the arch-conditional ("a") code target,
        # mirroring the flag _jit adds when compiling just-in-time.
        extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
    _generate_setup(name, sourcedir, extra_compile_args)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
|
||||
|
||||
|
||||
def _pytorch_grouped_gemm(
    op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
):
    """
    Emit the source files for a PyTorch CUDA module wrapping the CUTLASS grouped GEMM
    ``op``, optionally JIT-compiling and loading the resulting module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    # Grouped GEMM emission only understands the CUTLASS 2.x API surface.
    if op.api != ApiVersion.v2x:
        raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")

    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    # Render and write the CUDA (.cu) translation unit.
    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    kernel_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
    cuda_substitutions = {
        "includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
        "declaration": op.rt_module.emit(),
        "procedural_name": op.procedural_name(),
        "impl": kernel_impl,
        "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
    }
    with open(cuda_file, "w") as outfile:
        outfile.write(SubstituteTemplate(_PYTORCH_CUDA_TEMPLATE, cuda_substitutions))

    # Render and write the C++ binding (.cpp) translation unit.
    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_substitutions = {
        "name": name,
        "description": f"CUTLASS {op.procedural_name()} grouped GEMM",
    }
    with open(cpp_file, "w") as outfile:
        outfile.write(SubstituteTemplate(_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, cpp_substitutions))

    _generate_setup(name, sourcedir)

    return _jit(name, cc, cpp_file, cuda_file) if jit else None
|
||||
|
||||
|
||||
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that when the conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
    weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
    for H/W/R/S given the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    if op.conv_kind == ConvKind.Fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    else:
        # Previously an unsupported kind fell through the chain and later
        # raised a confusing NameError on ``impl_template``; fail explicitly.
        raise Exception(
            f"Conv kind {op.conv_kind} is not currently supported for PyTorch emission."
        )
    extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
    extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_CONV2D_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        cpp_template,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
|
||||
|
||||
|
||||
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    The result of this method is files within ``sourcedir`` that can be used for building
    a PyTorch module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module (if ``jit=True``) or None
    """
    device_op = op.device_op()

    # Pick the kind-specific emitter; ordering of the isinstance checks is
    # preserved from the original dispatch.
    if isinstance(op, GemmOperationUniversal):
        emitter = _pytorch_gemm
    elif isinstance(op, GemmOperationGrouped):
        emitter = _pytorch_grouped_gemm
    elif isinstance(op, Conv2dOperation):
        emitter = _pytorch_conv2d
    else:
        raise Exception(
            f"Operation type {type(op)} is not currently supported for PyTorch emission."
        )

    return emitter(device_op, name, cc, jit, sourcedir)
|
||||
Reference in New Issue
Block a user