[Bugfix] Don't build machete on cuda <12.0 (#7757)
This commit is contained in:
@ -37,9 +37,13 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
|
||||
//
|
||||
|
||||
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
|
||||
});
|
||||
#else
|
||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
@ -50,6 +54,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
c10::optional<torch::Tensor> const& C,
|
||||
c10::optional<double> alpha, c10::optional<double> beta,
|
||||
c10::optional<std::string> schedule) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
auto args = PyTorchArguments{.A = A,
|
||||
.B = B,
|
||||
.scales = scales,
|
||||
@ -67,13 +72,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
|
||||
});
|
||||
});
|
||||
#else
|
||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||
ScalarTypeTorchPtr const& btype) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
||||
});
|
||||
#else
|
||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
|
||||
Reference in New Issue
Block a user