[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

This commit is contained in:
bnellnm
2024-06-09 16:23:30 -04:00
committed by GitHub
parent 5d7e3d0176
commit 5467ac3196
55 changed files with 833 additions and 451 deletions

View File

@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
} // namespace
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon) {
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
}
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon) {
torch::Tensor& weight, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;