[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)
This commit is contained in:
@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
|
||||
} // namespace
|
||||
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
float epsilon) {
|
||||
double epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
}
|
||||
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, float epsilon) {
|
||||
torch::Tensor& weight, double epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user