[Bugfix] Fix support for dimension like integers and ScalarType (#9299)
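In short: the affected custom-op entry points stop taking a vllm::ScalarTypeTorchPtr (a torch custom-class pointer) and instead take a plain vllm::ScalarTypeId integer, rebuilding the full vllm::ScalarType on the C++ side with ScalarType::from_id(); pointer syntax (*b_q_type, b_q_type->str()) accordingly becomes value syntax (b_q_type, b_q_type.str()). A minimal sketch of the pattern, assuming only vLLM's core/scalar_type.hpp is on the include path (the function name and pack-factor example are illustrative, not code from this commit):

#include <cstdint>
#include "core/scalar_type.hpp"  // assumed include path for vllm::ScalarType

// Hypothetical op body following the commit's pattern: the quantized type
// crosses the op boundary as an integer id and is rebuilt here.
int64_t example_pack_factor(vllm::ScalarTypeId const b_q_type_id) {
  // Rebuild the full type descriptor from the id.
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  // Value member access replaces the old pointer dereference:
  //   *b_q_type == vllm::kU4B8  ->  b_q_type == vllm::kU4B8
  //   b_q_type->size_bits()     ->  b_q_type.size_bits()
  return 32 / b_q_type.size_bits();  // pack factor, e.g. 8 for 4-bit types
}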
@@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace,
-                               vllm::ScalarTypeTorchPtr const& b_q_type,
+                               vllm::ScalarTypeId const b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
                                bool is_k_full, bool has_zp) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
@@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace,
-                               vllm::ScalarTypeTorchPtr const& b_q_type,
+                               vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
                                bool is_k_full, bool has_zp,
                                bool use_fp32_reduce) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
-    TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
-                "b_q_type must be u4 or u8 when has_zp = True. Got = ",
-                b_q_type->str());
+    TORCH_CHECK(
+        b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
+        "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
   } else {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
+        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
         "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
-        b_q_type->str());
+        b_q_type.str());
   }
 
-  int pack_factor = 32 / b_q_type->size_bits();
+  int pack_factor = 32 / b_q_type.size_bits();
 
   // Verify A
   TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
         b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
-        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
+        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
@@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
         b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
-        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
+        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
@@ -2302,4 +2303,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
-}
+}
@@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
 // Interface
 //
 
-std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
+std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
 #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
-  return scalar_type_dispatch(*btype, [&](auto BType) {
+  vllm::ScalarType b_type = ScalarType::from_id(btype_id);
+  return scalar_type_dispatch(b_type, [&](auto BType) {
     return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
   });
 #else
@@ -49,7 +50,7 @@ std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
 }
 
 torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
-                   ScalarTypeTorchPtr const& btype,
+                   ScalarTypeId const btype_id,
                    c10::optional<torch::Tensor> const& scales,
                    c10::optional<torch::Tensor> const& zeros,
                    c10::optional<int64_t> group_size,
@@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                    c10::optional<double> alpha, c10::optional<double> beta,
                    c10::optional<std::string> schedule) {
 #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
+  ScalarType const btype = ScalarType::from_id(btype_id);
   auto args = PyTorchArguments{.A = A,
                                .B = B,
                                .scales = scales,
@@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                                .beta = beta,
                                .schedule = schedule};
 
-  return scalar_type_dispatch(*btype, [&](auto BType) {
+  return scalar_type_dispatch(btype, [&](auto BType) {
     return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
         A.scalar_type(), "machete_gemm", [&] {
           using ComputeType = equivalent_cutlass_type_t<scalar_t>;
@@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
 #endif
 }
 
-torch::Tensor prepack_B(torch::Tensor const& B,
-                        vllm::ScalarTypeTorchPtr const& btype) {
-  return scalar_type_dispatch(*btype, [&](auto BType) {
+torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
+  ScalarType const btype = ScalarType::from_id(btype_id);
+  return scalar_type_dispatch(btype, [&](auto BType) {
     return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
   });
 }
@@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
                                   torch::Tensor& workspace,
-                                  vllm::ScalarTypeTorchPtr const& b_q_type,
+                                  vllm::ScalarTypeId const b_q_type_id,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
                                   torch::Tensor& workspace,
-                                  vllm::ScalarTypeTorchPtr const& b_q_type,
+                                  vllm::ScalarTypeId const b_q_type_id,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   // Verify num_bits
-  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
-              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
-  int pack_factor = 32 / b_q_type->size_bits();
+  TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
+              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
+  int pack_factor = 32 / b_q_type.size_bits();
 
   // Verify M
   TORCH_CHECK(size_m == a.size(0),
@@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   marlin_24::marlin_cuda_2_4(
       a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
       b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
-      b_q_type->size_bits(), groupsize, dev,
-      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
+      b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
+      thread_k, thread_m, sms, max_par);
 
   return c;
 }
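For context on the id round-trip the call sites above rely on: vllm::ScalarType is assumed here to expose an id() accessor that is the inverse of ScalarType::from_id(), so encoding on one side of the op boundary and decoding on the other is lossless. A hedged caller-side sketch, not code from this commit:

#include <cassert>
#include "core/scalar_type.hpp"  // assumed include path for vllm::ScalarType

// Sketch: encode a type as its id before the op call, decode it inside.
void scalar_type_id_round_trip() {
  vllm::ScalarTypeId const id = vllm::kU4B8.id();  // id() is assumed, see above
  vllm::ScalarType const t = vllm::ScalarType::from_id(id);
  assert(t == vllm::kU4B8 && 32 / t.size_bits() == 8);  // round trip holds
}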