Allow changing epsilon parameter in RMS norm kernel (#1112)

This commit is contained in:
masahi
2023-10-03 09:40:28 +09:00
committed by GitHub
parent 26986bbc60
commit ff61a49dd1
2 changed files with 13 additions and 11 deletions

View File

@@ -43,7 +43,8 @@ using Layout = cutlass::layout::RowMajor;
void rmsnorm_host(cutlass::MatrixCoord tensor_size,
cutlass::TensorRef<ElementType, Layout> output,
cutlass::TensorRef<ElementType, Layout> input,
-cutlass::TensorRef<ElementType, Layout> weight) {
+cutlass::TensorRef<ElementType, Layout> weight,
+float epsilon) {
const int M = tensor_size.row();
const int N = tensor_size.column();
@@ -56,7 +57,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size,
}
float sq_mean = square_sum / (float)N;
-float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);
+float sqrt_var = cutlass::fast_sqrt(sq_mean + epsilon);
for (int n = 0; n < N; ++n) {
float inp = static_cast<float>(input.at({m, n}));
@@ -91,9 +92,9 @@ void run_test(int M, int N) {
input.sync_device();
weight.sync_device();
-rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
+rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
cutlass::rmsnorm({M, N}, output.device_ref(),
-input.device_ref(), weight.device_ref(), NULL);
+input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
output.sync_host();