[Bugfix] Don't build machete on cuda <12.0 (#7757)

This commit is contained in:
Lucas Wilkinson
2024-08-22 08:28:52 -04:00
committed by GitHub
parent 4f419c00a6
commit 55d63b1211
2 changed files with 48 additions and 28 deletions

View File

@ -37,9 +37,13 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
@ -50,6 +54,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
@ -67,13 +72,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
}; // namespace machete