[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Meta (shape-inference) implementation of awq_marlin_repack.
//
// Runs on the "meta" device: it computes only the symbolic output shape,
// never touching real data, so torch.compile can trace through the custom
// op without a graph break.
//
// Args:
//   b_q_weight: AWQ-quantized weight tensor; only its dtype and device are
//               read here.
//   size_k / size_n: symbolic GEMM dimensions of the unpacked weight.
//   num_bits: quantization bit-width; must be 4 or 8 so that it divides 32
//             evenly (guards the pack_factor division below).
// Returns: an uninitialized tensor with the marlin repacked layout's shape,
//          {size_k / tile_size, size_n * tile_size / pack_factor}.
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits) {
  // Reject unsupported widths up front; num_bits == 0 would otherwise be a
  // division by zero (UB), and other values give a meaningless pack_factor.
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8, got = ", num_bits);
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
|
||||
|
||||
@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Meta (shape-inference) implementation of gptq_marlin_repack.
//
// Runs on the "meta" device: it computes only the symbolic output shape,
// never touching real data, so torch.compile can trace through the custom
// op without a graph break.
//
// Args:
//   b_q_weight: GPTQ-quantized weight tensor; only its dtype and device are
//               read here.
//   perm: activation-order permutation tensor; unused for shape inference
//         but kept so the meta signature mirrors the real kernel.
//   size_k / size_n: symbolic GEMM dimensions of the unpacked weight.
//   num_bits: quantization bit-width; must be 4 or 8 so that it divides 32
//             evenly (guards the pack_factor division below).
// Returns: an uninitialized tensor with the marlin repacked layout's shape,
//          {size_k / tile_size, size_n * tile_size / pack_factor}.
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits) {
  // Reject unsupported widths up front; num_bits == 0 would otherwise be a
  // division by zero (UB), and other values give a meaningless pack_factor.
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8, got = ", num_bits);
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
|
||||
|
||||
Reference in New Issue
Block a user