Allow changing epsilon parameter in RMS norm kernel (#1112)
This commit is contained in:
@ -43,7 +43,8 @@ using Layout = cutlass::layout::RowMajor;
|
||||
void rmsnorm_host(cutlass::MatrixCoord tensor_size,
|
||||
cutlass::TensorRef<ElementType, Layout> output,
|
||||
cutlass::TensorRef<ElementType, Layout> input,
|
||||
cutlass::TensorRef<ElementType, Layout> weight) {
|
||||
cutlass::TensorRef<ElementType, Layout> weight,
|
||||
float epsilon) {
|
||||
const int M = tensor_size.row();
|
||||
const int N = tensor_size.column();
|
||||
|
||||
@ -56,7 +57,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size,
|
||||
}
|
||||
|
||||
float sq_mean = square_sum / (float)N;
|
||||
float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);
|
||||
float sqrt_var = cutlass::fast_sqrt(sq_mean + epsilon);
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
float inp = static_cast<float>(input.at({m, n}));
|
||||
@ -91,9 +92,9 @@ void run_test(int M, int N) {
|
||||
input.sync_device();
|
||||
weight.sync_device();
|
||||
|
||||
rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
|
||||
rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
|
||||
cutlass::rmsnorm({M, N}, output.device_ref(),
|
||||
input.device_ref(), weight.device_ref(), NULL);
|
||||
input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
|
||||
|
||||
output.sync_host();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user