release 2.11 (#703)
This commit is contained in:
@ -85,18 +85,11 @@ You can run the PyCUTLASS on NGC PyTorch container.
|
||||
```shell
|
||||
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3
|
||||
```
|
||||
PyCUTLASS requires additional dependency Boost C++ library, which can be installed with
|
||||
```bash
|
||||
apt-get update
|
||||
apt-get -y install libboost-all-dev
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Environment variables
|
||||
PyCUTLASS requires two environment variables:
|
||||
* `CUTLASS_PATH`: the root directory of CUTLASS
|
||||
* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed
|
||||
* `CUTLASS_PATH`: the root directory of CUTLASS. You can set this from the location at which you cloned CUTLASS via: `export CUTLASS_PATH=$(pwd)`.
|
||||
* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed. If running in bash with `nvcc` installed under a CUDA toolkit, you can set this to the location of your `nvcc` installation via: `export CUDA_INSTALL_PATH=$(which nvcc | awk -F'/bin/nvcc' '{print $1}')`
|
||||
|
||||
After setting these two environment variables, PyCUTLASS can be installed with
|
||||
```shell
|
||||
|
||||
@ -38,6 +38,7 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/params_universal_base.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
@ -104,16 +105,12 @@ public:
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
struct Arguments : UniversalArgumentsBase {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
void const * ptr_A;
|
||||
@ -124,7 +121,6 @@ public:
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename LayoutA::Stride stride_a;
|
||||
typename LayoutB::Stride stride_b;
|
||||
@ -145,8 +141,6 @@ public:
|
||||
//
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
@ -174,12 +168,10 @@ public:
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
@ -212,12 +204,10 @@ public:
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
@ -248,11 +238,19 @@ public:
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
struct Params : UniversalParamsBase<
|
||||
ThreadblockSwizzle,
|
||||
ThreadblockShape,
|
||||
ElementA,
|
||||
ElementB,
|
||||
ElementC> {
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
using ParamsBase = UniversalParamsBase<
|
||||
ThreadblockSwizzle,
|
||||
ThreadblockShape,
|
||||
ElementA,
|
||||
ElementB,
|
||||
ElementC>;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
@ -261,10 +259,6 @@ public:
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
@ -273,7 +267,6 @@ public:
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
int * ptr_gather_A_indices;
|
||||
int * ptr_gather_B_indices;
|
||||
@ -285,47 +278,21 @@ public:
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_D(0),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr),
|
||||
semaphore(nullptr) { }
|
||||
/// Default constructor
|
||||
Params() = default;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
int gemm_k_size,
|
||||
void *workspace = nullptr
|
||||
int device_sms,
|
||||
int sm_occupancy
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
ParamsBase(args, device_sms, sm_occupancy),
|
||||
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
|
||||
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
|
||||
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
|
||||
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
|
||||
epilogue_visitor(args.epilogue_visitor),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(gemm_k_size),
|
||||
ptr_A(const_cast<void *>(args.ptr_A)),
|
||||
ptr_B(const_cast<void *>(args.ptr_B)),
|
||||
ptr_C(const_cast<void *>(args.ptr_C)),
|
||||
@ -333,11 +300,9 @@ public:
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
|
||||
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
|
||||
semaphore(static_cast<int *>(workspace)) {
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)) {
|
||||
|
||||
}
|
||||
|
||||
@ -358,7 +323,6 @@ public:
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
epilogue_visitor = args.epilogue_visitor;
|
||||
|
||||
@ -466,12 +430,6 @@ public:
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
@ -38,11 +38,19 @@
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include <boost/core/demangle.hpp>
|
||||
#include <cxxabi.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
std::string demangle(const char* mangled_name) {
|
||||
std::size_t len = 0;
|
||||
int status = 0;
|
||||
std::unique_ptr<char> ptr(
|
||||
__cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status));
|
||||
return ptr.get();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void bind_identity_swizzle(py::module & m, std::string name) {
|
||||
py::class_<T>(m, name.c_str(),
|
||||
@ -80,7 +88,7 @@ void bind_identity_swizzle(py::module & m, std::string name) {
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
return demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
@ -101,7 +109,7 @@ void bind_swizzle(py::module & m, std::string name, std::string doc) {
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
return demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
@ -124,7 +132,7 @@ void bind_dgrad_swizzle(py::module & m, std::string name) {
|
||||
}, py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
return demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
|
||||
@ -69,9 +69,12 @@ def get_gemm_arguments(epilogue_functor):
|
||||
|
||||
class _GemmArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
# Arguments from UniversalArgumentsBase
|
||||
("mode", ctypes.c_int),
|
||||
("problem_size", GemmCoord_),
|
||||
("batch_count", ctypes.c_int),
|
||||
("batch_stride_D", ctypes.c_longlong),
|
||||
# Remaining arguments
|
||||
("epilogue", _EpilogueOutputOpParams),
|
||||
("ptr_A", ctypes.c_void_p),
|
||||
("ptr_B", ctypes.c_void_p),
|
||||
@ -80,7 +83,6 @@ def get_gemm_arguments(epilogue_functor):
|
||||
("batch_stride_A", ctypes.c_longlong),
|
||||
("batch_stride_B", ctypes.c_longlong),
|
||||
("batch_stride_C", ctypes.c_longlong),
|
||||
("batch_stride_D", ctypes.c_longlong),
|
||||
("stride_a", ctypes.c_longlong),
|
||||
("stride_b", ctypes.c_longlong),
|
||||
("stride_c", ctypes.c_longlong),
|
||||
|
||||
@ -229,7 +229,7 @@ class GemmArguments(ArgumentBase):
|
||||
elif operand in ["c", "d"]:
|
||||
tensor_coord = problem_size.mn()
|
||||
else:
|
||||
raise ValueError("unknonw operand: " + operand)
|
||||
raise ValueError("unknown operand: " + operand)
|
||||
|
||||
layout = tensor_layout.packed(tensor_coord)
|
||||
|
||||
@ -245,22 +245,27 @@ class GemmArguments(ArgumentBase):
|
||||
)
|
||||
if self.gemm_mode == cutlass.gemm.Mode.Array:
|
||||
arguments = self.operation.argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
|
||||
# Arguments from UniversalArgumentsBase
|
||||
self.gemm_mode, problem_size_, self.batch_count, 0,
|
||||
# Remaining arguments
|
||||
self.output_op,
|
||||
int(self.ptr_A_array_buffer.ptr),
|
||||
int(self.ptr_B_array_buffer.ptr),
|
||||
int(self.ptr_C_array_buffer.ptr),
|
||||
int(self.ptr_D_array_buffer.ptr),
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
)
|
||||
else:
|
||||
arguments = self.operation.argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
|
||||
# Arguments from UniversalArgumentsBase
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
|
||||
# Remaining arguments
|
||||
self.output_op,
|
||||
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
|
||||
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
|
||||
self.batched_stride_D,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
@ -299,8 +304,7 @@ class GemmArguments(ArgumentBase):
|
||||
|
||||
arguments, grid_tiled_shape, gemm_k_size = self.arguments
|
||||
res_arg = self.operation.rt_module.get_args(
|
||||
ctypes.byref(arguments), ctypes.byref(grid_tiled_shape),
|
||||
gemm_k_size, ctypes.c_void_p(int(device_workspace)))
|
||||
ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
|
||||
host_workspace = bytearray(res_arg.contents)
|
||||
|
||||
device_workspace = None
|
||||
@ -582,10 +586,15 @@ extern "C" {
|
||||
}
|
||||
|
||||
// Get the params as byte array
|
||||
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, \
|
||||
cutlass::gemm::GemmCoord* grid_tiled_shape, int gemm_k_size, int* workspace){
|
||||
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
|
||||
${operation_name}_base::Params* params;
|
||||
params = new ${operation_name}_base::Params(*argument, *grid_tiled_shape, gemm_k_size, workspace);
|
||||
params = new ${operation_name}_base::Params(*argument,
|
||||
-1, // SM count. Only used for stream-K
|
||||
-1 // Occupancy. Only used for stream-K
|
||||
);
|
||||
|
||||
// Semaphore holds the pointer to the workspace in the Params struct
|
||||
params->semaphore = workspace;
|
||||
|
||||
char *bytes = ((char*)(params));
|
||||
char *output = new char[sizeof(${operation_name}_base::Params)];
|
||||
|
||||
@ -116,13 +116,11 @@ DataTypeNames = {
|
||||
|
||||
DataTypeTag = {
|
||||
cutlass.dtype.b1: "cutlass::uint1b_t",
|
||||
cutlass.dtype.u2: "cutlass::uint2b_t",
|
||||
cutlass.dtype.u4: "cutlass::uint4b_t",
|
||||
cutlass.dtype.u8: "uint8_t",
|
||||
cutlass.dtype.u16: "uint16_t",
|
||||
cutlass.dtype.u32: "uint32_t",
|
||||
cutlass.dtype.u64: "uint64_t",
|
||||
cutlass.dtype.s2: "cutlass::int2b_t",
|
||||
cutlass.dtype.s4: "cutlass::int4b_t",
|
||||
cutlass.int8: "int8_t",
|
||||
cutlass.dtype.s16: "int16_t",
|
||||
@ -138,13 +136,11 @@ DataTypeTag = {
|
||||
cutlass.dtype.cf32: "cutlass::complex<float>",
|
||||
cutlass.dtype.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
|
||||
cutlass.dtype.cf64: "cutlass::complex<double>",
|
||||
cutlass.dtype.cu2: "cutlass::complex<cutlass::uint2b_t>",
|
||||
cutlass.dtype.cu4: "cutlass::complex<cutlass::uint4b_t>",
|
||||
cutlass.dtype.cu8: "cutlass::complex<cutlass::uint8_t>",
|
||||
cutlass.dtype.cu16: "cutlass::complex<cutlass::uint16_t>",
|
||||
cutlass.dtype.cu32: "cutlass::complex<cutlass::uint32_t>",
|
||||
cutlass.dtype.cu64: "cutlass::complex<cutlass::uint64_t>",
|
||||
cutlass.dtype.cs2: "cutlass::complex<cutlass::int2b_t>",
|
||||
cutlass.dtype.cs4: "cutlass::complex<cutlass::int4b_t>",
|
||||
cutlass.dtype.cs8: "cutlass::complex<cutlass::int8_t>",
|
||||
cutlass.dtype.cs16: "cutlass::complex<cutlass::int16_t>",
|
||||
|
||||
2
tools/library/scripts/pycutlass/test/example/run_all_example.sh
Normal file → Executable file
2
tools/library/scripts/pycutlass/test/example/run_all_example.sh
Normal file → Executable file
@ -1,4 +1,4 @@
|
||||
pushd $CUTLASS_PATH/examples/40_cutlass_py/
|
||||
pushd $CUTLASS_PATH/examples/40_cutlass_py/customizable
|
||||
|
||||
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user