[CPU] Refactor CPU unquantized linear (#24150)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-09-04 14:28:45 +08:00
committed by GitHub
parent cb55ad86fe
commit 57b1ce94f7
9 changed files with 466 additions and 26 deletions

View File

@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
def onednn_gemm_test_helper(primitive_cache_size: int,
                            m: int,
                            n: int,
                            k: int,
                            use_bias: bool,
                            use_stride: bool,
                            dtype: torch.dtype = torch.bfloat16,
                            device: str = "cpu"):
    """Check ``ops.onednn_mm`` against a float32 ``linear`` reference.

    Random inputs ``a`` of shape (m, k) and ``b`` of shape (n, k) are
    generated, a handler is created from ``b.t()`` via
    ``ops.create_onednn_mm``, and the result of ``ops.onednn_mm`` is compared
    with ``torch.nn.functional.linear`` evaluated in float32 and cast back to
    the input dtype.

    Args:
        primitive_cache_size: Forwarded to ``ops.create_onednn_mm``.
        m: Number of rows of the input ``a``.
        n: Number of rows of ``b`` (and length of the bias).
        k: Shared inner dimension of ``a`` and ``b``.
        use_bias: When true, a random bias is passed on the first call and a
            second call with ``None`` as bias is checked on the same handler.
        use_stride: When true, ``a`` is the first ``k`` columns of an
            (m, 2 * k) tensor instead of a freshly allocated (m, k) tensor.
        dtype: Element type of the generated tensors.
        device: Device on which the tensors are created.
    """
    # A strided input is obtained by slicing a buffer twice as wide; the
    # scaling happens before the slice so the result stays a view.
    a_cols = 2 * k if use_stride else k
    a = torch.rand((m, a_cols), dtype=dtype, device=device) * 1.5
    if use_stride:
        a = a[:, :k]

    b = torch.rand((n, k), dtype=dtype, device=device) * 1.5

    bias = (torch.rand((n, ), device=device, dtype=dtype) *
            5 if use_bias else None)
    bias_f32 = bias.float() if use_bias else None

    def reference(ref_bias):
        # Float32 reference result, cast back to the dtype under test.
        expected = torch.nn.functional.linear(a.float(), b.float(), ref_bias)
        return expected.to(dtype=a.dtype)

    handler = ops.create_onednn_mm(b.t(), primitive_cache_size)

    torch.testing.assert_close(ops.onednn_mm(handler, a, bias),
                               reference(bias_f32))

    if not use_bias:
        return

    # To test runtime bias setting: same handler, this time without a bias.
    torch.testing.assert_close(ops.onednn_mm(handler, a, None),
                               reference(None))
@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("per_tensor_a_scale", [True, False])
@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm(
use_azp=use_azp,
out_dtype=output_type,
)
@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_stride", [True, False])
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
def test_onednn_gemm(
    n: int,
    k: int,
    m_list: tuple[int, ...],
    use_bias: bool,
    use_stride: bool,
    dtype: torch.dtype,
    primitive_cache_size: int,
) -> None:
    """Run ``onednn_gemm_test_helper`` once for every ``m`` in ``m_list``.

    All other parameters are forwarded unchanged to the helper; the
    parametrize decorators supply every combination of the listed values.

    Args:
        n: Forwarded as the helper's ``n``.
        k: Forwarded as the helper's ``k``.
        m_list: Values of ``m`` to run, one helper call per entry. Annotated
            as a variable-length tuple because it is iterated, whereas
            ``tuple[int]`` would describe a tuple of exactly one element.
        use_bias: Forwarded as the helper's ``use_bias``.
        use_stride: Forwarded as the helper's ``use_stride``.
        dtype: Forwarded as the helper's ``dtype``.
        primitive_cache_size: Forwarded as the helper's
            ``primitive_cache_size``.
    """
    for m in m_list:
        onednn_gemm_test_helper(
            primitive_cache_size=primitive_cache_size,
            m=m,
            n=n,
            k=k,
            use_bias=use_bias,
            use_stride=use_stride,
            dtype=dtype,
        )