[CPU] Refactor CPU unquantized linear (#24150)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def onednn_gemm_test_helper(primitive_cache_size: int,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
use_bias: bool,
|
||||
use_stride: bool,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu"):
|
||||
if use_stride:
|
||||
a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
|
||||
a = a[:, :k]
|
||||
else:
|
||||
a = torch.rand((m, k), dtype=dtype, device=device) * 1.5
|
||||
|
||||
b = torch.rand((n, k), dtype=dtype, device=device) * 1.5
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=dtype) * 5
|
||||
bias_f32 = bias.float()
|
||||
else:
|
||||
bias = None
|
||||
bias_f32 = None
|
||||
|
||||
handler = ops.create_onednn_mm(
|
||||
b.t(),
|
||||
primitive_cache_size,
|
||||
)
|
||||
|
||||
out = ops.onednn_mm(handler, a, bias)
|
||||
baseline = torch.nn.functional.linear(a.float(), b.float(),
|
||||
bias_f32).to(dtype=a.dtype)
|
||||
|
||||
torch.testing.assert_close(out, baseline)
|
||||
|
||||
if use_bias:
|
||||
# To test runtime bias setting
|
||||
out = ops.onednn_mm(handler, a, None)
|
||||
baseline = torch.nn.functional.linear(a.float(), b.float(),
|
||||
None).to(dtype=a.dtype)
|
||||
|
||||
torch.testing.assert_close(out, baseline)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k", NK_FACTORS)
|
||||
@pytest.mark.parametrize("m_list", M_FACTORS)
|
||||
@pytest.mark.parametrize("per_tensor_a_scale", [True, False])
|
||||
@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm(
|
||||
use_azp=use_azp,
|
||||
out_dtype=output_type,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k", NK_FACTORS)
|
||||
@pytest.mark.parametrize("m_list", M_FACTORS)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("use_stride", [True, False])
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
|
||||
def test_onednn_gemm(
|
||||
n: int,
|
||||
k: int,
|
||||
m_list: tuple[int],
|
||||
use_bias: bool,
|
||||
use_stride: bool,
|
||||
dtype: torch.dtype,
|
||||
primitive_cache_size: int,
|
||||
):
|
||||
for m in m_list:
|
||||
onednn_gemm_test_helper(
|
||||
primitive_cache_size=primitive_cache_size,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
use_bias=use_bias,
|
||||
use_stride=use_stride,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user