[Bugfix] Fix support for dimension like integers and ScalarType (#9299)

This commit is contained in:
bnellnm
2024-10-17 15:08:34 -04:00
committed by GitHub
parent 0f41fbe5a3
commit eca2c5f7c0
22 changed files with 427 additions and 677 deletions

View File

@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const& b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
if (has_zp) {
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
b_q_type->str());
TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type->str());
b_q_type.str());
}
int pack_factor = 32 / b_q_type->size_bits();
int pack_factor = 32 / b_q_type.size_bits();
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else {
@ -2302,4 +2303,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
}
}

View File

@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
vllm::ScalarType b_type = ScalarType::from_id(btype_id);
return scalar_type_dispatch(b_type, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
@ -49,7 +50,7 @@ std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeTorchPtr const& btype,
ScalarTypeId const btype_id,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
ScalarType const btype = ScalarType::from_id(btype_id);
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(*btype, [&](auto BType) {
return scalar_type_dispatch(btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
#endif
}
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
ScalarType const btype = ScalarType::from_id(btype_id);
return scalar_type_dispatch(btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
}

View File

@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(
@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n,
int64_t size_k) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
// Verify num_bits
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
int pack_factor = 32 / b_q_type->size_bits();
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
int pack_factor = 32 / b_q_type.size_bits();
// Verify M
TORCH_CHECK(size_m == a.size(0),
@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
marlin_24::marlin_cuda_2_4(
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
b_q_type->size_bits(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_m, sms, max_par);
return c;
}