diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 582e4aae78..f3ec6b5035 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -134,10 +134,7 @@ def matmul_kernel_persistent( bias_ptrs = bias_ptr + offs_cn bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) accumulator += bias - if c_ptr.dtype.element_ty == tl.float8e4nv: - c = accumulator.to(tl.float8e4nv) - else: - c = accumulator.to(tl.float16) + c = accumulator.to(c_ptr.dtype.element_ty) tl.store(c_ptrs, c, mask=c_mask)